Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-04-20 03:00:34 -04:00)

Compare commits: 20230722.8...20230719.8 (1 commit)

Commit 29963c24ca
@@ -2,7 +2,7 @@ import copy

import torch

from evaluate_params import eval_func_param_names
-from gen import Langchain
+from gen import get_score_model, get_model, evaluate, check_locals
from prompter import non_hf_types
from utils import clear_torch_cache, NullContext, get_kwargs

@@ -87,7 +87,7 @@ def run_cli( # for local function:
    # unique to this function:
    cli_loop=None,
):
-    Langchain.check_locals(**locals())
+    check_locals(**locals())

    score_model = ""  # FIXME: For now, so user doesn't have to pass
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
@@ -98,20 +98,16 @@ def run_cli( # for local function:
    from functools import partial

    # get score model
-    smodel, stokenizer, sdevice = Langchain.get_score_model(
+    smodel, stokenizer, sdevice = get_score_model(
        reward_type=True,
        **get_kwargs(
-            Langchain.get_score_model,
-            exclude_names=["reward_type"],
-            **locals()
+            get_score_model, exclude_names=["reward_type"], **locals()
        )
    )

-    model, tokenizer, device = Langchain.get_model(
+    model, tokenizer, device = get_model(
        reward_type=False,
-        **get_kwargs(
-            Langchain.get_model, exclude_names=["reward_type"], **locals()
-        )
+        **get_kwargs(get_model, exclude_names=["reward_type"], **locals())
    )
    model_dict = dict(
        base_model=base_model,
@@ -125,11 +121,11 @@ def run_cli( # for local function:
    model_state.update(model_dict)
    my_db_state = [None]
    fun = partial(
-        Langchain.evaluate,
+        evaluate,
        model_state,
        my_db_state,
        **get_kwargs(
-            Langchain.evaluate,
+            evaluate,
            exclude_names=["model_state", "my_db_state"]
            + eval_func_param_names,
            **locals()

@@ -7,7 +7,7 @@ import torch
from matplotlib import pyplot as plt

from evaluate_params import eval_func_param_names, eval_extra_columns
-from gen import Langchain
+from gen import get_context, get_score_model, get_model, evaluate, check_locals
from prompter import Prompter
from utils import clear_torch_cache, NullContext, get_kwargs

@@ -94,7 +94,7 @@ def run_eval( # for local function:
    force_langchain_evaluate=None,
    model_state_none=None,
):
-    Langchain.check_locals(**locals())
+    check_locals(**locals())

    if eval_prompts_only_num > 0:
        np.random.seed(eval_prompts_only_seed)
@@ -144,7 +144,7 @@ def run_eval( # for local function:
            ] = ""  # no input
            examplenew[
                eval_func_param_names.index("context")
-            ] = Langchain.get_context(chat_context, prompt_type)
+            ] = get_context(chat_context, prompt_type)
            examples.append(examplenew)
            responses.append(output)
        else:
@@ -170,7 +170,7 @@ def run_eval( # for local function:
            ] = ""  # no input
            examplenew[
                eval_func_param_names.index("context")
-            ] = Langchain.get_context(chat_context, prompt_type)
+            ] = get_context(chat_context, prompt_type)
            examples.append(examplenew)
            responses.append(output)

@@ -210,22 +210,18 @@ def run_eval( # for local function:
    from functools import partial

    # get score model
-    smodel, stokenizer, sdevice = Langchain.get_score_model(
+    smodel, stokenizer, sdevice = get_score_model(
        reward_type=True,
        **get_kwargs(
-            Langchain.get_score_model,
-            exclude_names=["reward_type"],
-            **locals()
+            get_score_model, exclude_names=["reward_type"], **locals()
        )
    )

    if not eval_as_output:
-        model, tokenizer, device = Langchain.get_model(
+        model, tokenizer, device = get_model(
            reward_type=False,
            **get_kwargs(
-                Langchain.get_model,
-                exclude_names=["reward_type"],
-                **locals()
+                get_model, exclude_names=["reward_type"], **locals()
            )
        )
        model_dict = dict(
@@ -240,11 +236,11 @@ def run_eval( # for local function:
    model_state.update(model_dict)
    my_db_state = [None]
    fun = partial(
-        Langchain.evaluate,
+        evaluate,
        model_state,
        my_db_state,
        **get_kwargs(
-            Langchain.evaluate,
+            evaluate,
            exclude_names=["model_state", "my_db_state"]
            + eval_func_param_names,
            **locals()

File diff suppressed because it is too large
@@ -34,7 +34,7 @@ from enums import (
    LangChainMode,
)
from evaluate_params import gen_hyper
-from gen import Langchain, SEED
+from gen import get_model, SEED
from prompter import non_hf_types, PromptType, Prompter
from utils import (
    wrapped_partial,
@@ -907,7 +907,7 @@ def get_llm(
        # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
        # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
        inference_server = ""
-        model, tokenizer, device = Langchain.get_model(
+        model, tokenizer, device = get_model(
            load_8bit=True,
            base_model=model_name,
            inference_server=inference_server,

@@ -34,22 +34,20 @@ class H2OGPTSHARKModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        model_name = "h2ogpt_falcon_7b"
-        extended_model_name = (
-            model_name + "_" + args.precision + "_" + args.device
+        path_str = (
+            model_name + "_" + args.precision + "_" + args.device + ".vmfb"
        )
-        vmfb_path = Path(extended_model_name + ".vmfb")
-        mlir_path = Path(model_name + "_" + args.precision + ".mlir")
+        vmfb_path = Path(path_str)
+        path_str = model_name + "_" + args.precision + ".mlir"
+        mlir_path = Path(path_str)
        shark_module = None

        if not vmfb_path.exists():
-            if args.device in ["cuda", "cpu"] and args.precision in [
-                "fp16",
-                "fp32",
-            ]:
+            if args.device == "cuda" and args.precision in ["fp16", "fp32"]:
                # Downloading VMFB from shark_tank
                print("Downloading vmfb from shark tank.")
                download_public_file(
-                    "gs://shark_tank/langchain/" + str(vmfb_path),
+                    "gs://shark_tank/langchain/" + path_str,
                    vmfb_path.absolute(),
                    single_file=True,
                )
@@ -60,7 +58,11 @@ class H2OGPTSHARKModel(torch.nn.Module):
            else:
                # Downloading MLIR from shark_tank
                download_public_file(
-                    "gs://shark_tank/langchain/" + str(mlir_path),
+                    "gs://shark_tank/langchain/"
+                    + model_name
+                    + "_"
+                    + args.precision
+                    + ".mlir",
                    mlir_path.absolute(),
                    single_file=True,
                )
@@ -78,18 +80,16 @@ class H2OGPTSHARKModel(torch.nn.Module):
                    mlir_dialect="linalg",
                )
                print(f"[DEBUG] generating vmfb.")
-                shark_module = _compile_module(
-                    shark_module, extended_model_name, []
-                )
+                shark_module = _compile_module(shark_module, vmfb_path, [])
                print("Saved newly generated vmfb.")

        if shark_module is None:
            if vmfb_path.exists():
                print("Compiled vmfb found. Loading it from: ", vmfb_path)
                shark_module = SharkInference(
-                    None, device=args.device, mlir_dialect="linalg"
+                    None, device=global_device, mlir_dialect="linalg"
                )
-                shark_module.load_module(str(vmfb_path))
+                shark_module.load_module(vmfb_path)
                print("Compiled vmfb loaded successfully.")
            else:
                raise ValueError("Unable to download/generate a vmfb.")
@@ -102,7 +102,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
                "forward",
                (input_ids.to(device="cpu"), attention_mask.to(device="cpu")),
            )
-        ).to(device=args.device)
+        ).to(device=global_device)
        return result


@@ -118,14 +118,14 @@ def pad_or_truncate_inputs(
        num_add_token = max_padding_length - inp_shape[1]
        padded_input_ids = torch.cat(
            [
-                torch.tensor([[11] * num_add_token]).to(device=args.device),
+                torch.tensor([[11] * num_add_token]).to(device=global_device),
                input_ids,
            ],
            dim=1,
        )
        padded_attention_mask = torch.cat(
            [
-                torch.tensor([[0] * num_add_token]).to(device=args.device),
+                torch.tensor([[0] * num_add_token]).to(device=global_device),
                attention_mask,
            ],
            dim=1,
@@ -328,7 +328,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
            model_inputs["input_ids"], model_inputs["attention_mask"]
        )

-        if args.precision == "fp16":
+        if global_precision == "fp16":
            outputs = outputs.to(dtype=torch.float32)
        next_token_logits = outputs

@@ -455,7 +455,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        self.eos_token_id_tensor = (
-            torch.tensor(eos_token_id).to(device=args.device)
+            torch.tensor(eos_token_id).to(device=global_device)
            if eos_token_id is not None
            else None
        )
@@ -533,7 +533,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
            self.input_ids = torch.cat(
                [
                    torch.tensor(self.truncated_input_ids)
-                    .to(device=args.device)
+                    .to(device=global_device)
                    .unsqueeze(dim=0),
                    self.input_ids,
                ],
@@ -612,9 +612,22 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
            **generate_kwargs,
        )
        out_b = generated_sequence.shape[0]
-        generated_sequence = generated_sequence.reshape(
-            in_b, out_b // in_b, *generated_sequence.shape[1:]
-        )
+        if self.framework == "pt":
+            generated_sequence = generated_sequence.reshape(
+                in_b, out_b // in_b, *generated_sequence.shape[1:]
+            )
+        elif self.framework == "tf":
+            from transformers import is_tf_available
+
+            if is_tf_available():
+                import tensorflow as tf
+
+                generated_sequence = tf.reshape(
+                    generated_sequence,
+                    (in_b, out_b // in_b, *generated_sequence.shape[1:]),
+                )
+            else:
+                raise ValueError("TF not avaialble.")
        return {
            "generated_sequence": generated_sequence,
            "input_ids": input_ids,

@@ -6,6 +6,7 @@ huggingface_hub==0.15.1
appdirs==1.4.4
fire==0.5.0
docutils==0.20.1
+torch==2.0.1
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
@@ -19,6 +20,7 @@ loralib==0.1.1
bitsandbytes==0.39.0
accelerate==0.20.3
git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
+transformers==4.30.2
tokenizers==0.13.3
APScheduler==3.10.1

@@ -61,46 +63,3 @@ text-generation==0.6.0
tiktoken==0.4.0
# optional: for OpenAI endpoint or embeddings (requires key)
openai==0.27.8
-
-# optional for chat with PDF
-langchain==0.0.235
-pypdf==3.12.2
-# avoid textract, requires old six
-#textract==1.6.5
-
-# for HF embeddings
-sentence_transformers==2.2.2
-
-# local vector db
-chromadb==0.3.25
-# server vector db
-#pymilvus==2.2.8
-
-# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
-# unstructured==0.8.1
-
-# strong support for images
-# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
-unstructured[local-inference]==0.7.4
-#pdf2image==1.16.3
-#pytesseract==0.3.10
-pillow
-
-pdfminer.six==20221105
-urllib3
-requests_file
-
-#pdf2image==1.16.3
-#pytesseract==0.3.10
-tabulate==0.9.0
-# FYI pandoc already part of requirements.txt
-
-# JSONLoader, but makes some trouble for some users
-# jq==1.4.1
-
-# to check licenses
-# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
-pip-licenses==4.3.0
-
-# weaviate vector db
-weaviate-client==3.22.1

@@ -102,24 +102,11 @@ parser.add_argument(
    help="Group size for per_group weight quantization. Default: 128.",
)
parser.add_argument(
-    "--download_vmfb",
-    default=False,
-    action=argparse.BooleanOptionalAction,
-    help="download vmfb from sharktank, system dependent, YMMV",
-)
-parser.add_argument(
-    "--model_name",
-    type=str,
+    "--model_to_run",
    default="vicuna",
    choices=["vicuna", "llama2_7b", "llama2_70b"],
-    help="Specify which model to run.",
-)
-parser.add_argument(
-    "--hf_auth_token",
-    type=str,
-    default=None,
-    help="Specify your own huggingface authentication tokens for models like Llama2.",
+    help="Vicuna/Llama version to run",
)
+parser.add_argument("--download_vmfb", default=False, action=argparse.BooleanOptionalAction, help="download vmfb from sharktank, system dependent, YMMV")


def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
@@ -724,13 +711,20 @@ class ShardedVicuna(SharkLLMBase):
                quantize_model(
                    get_model_impl(vicuna_model).layers,
                    dtype=torch.float32,
-                    weight_quant_type="asym",
                    weight_bit_width=weight_bit_width,
                    weight_param_method="stats",
                    weight_scale_precision="float",
+                    weight_quant_type="asym",
                    weight_quant_granularity="per_group",
                    weight_group_size=self.weight_group_size,
                    quantize_weight_zero_point=False,
+                    input_bit_width=None,
+                    input_scale_type="float",
+                    input_param_method="stats",
+                    input_quant_type="asym",
+                    input_quant_granularity="per_tensor",
+                    quantize_input_zero_point=False,
                    seqlen=2048,
                )
                print("Weight quantization applied.")
@@ -888,7 +882,6 @@ class UnshardedVicuna(SharkLLMBase):
        self,
        model_name,
        hf_model_path="TheBloke/vicuna-7B-1.1-HF",
-        hf_auth_token: str = None,
        max_num_tokens=512,
        device="cuda",
        precision="fp32",
@@ -902,15 +895,8 @@ class UnshardedVicuna(SharkLLMBase):
        download_vmfb=False,
    ) -> None:
        super().__init__(model_name, hf_model_path, max_num_tokens)
-        if "llama2" in self.model_name and hf_auth_token == None:
-            raise ValueError(
-                "HF auth token required. Pass it using --hf_auth_token flag."
-            )
-        self.hf_auth_token = hf_auth_token
-        if self.model_name == "llama2_7b":
+        if self.model_name == "llama2":
            self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
-        elif self.model_name == "llama2_70b":
-            self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
        print(f"[DEBUG] hf model name: {self.hf_model_path}")
        self.max_sequence_length = 256
        self.device = device
@@ -949,7 +935,11 @@ class UnshardedVicuna(SharkLLMBase):
        )

    def get_tokenizer(self):
-        kwargs = {"use_auth_token": self.hf_auth_token}
+        kwargs = {}
+        if self.model_name == "llama2":
+            kwargs = {
+                "use_auth_token": "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"
+            }
        if self.model_name == "codegen":
            tokenizer = AutoTokenizer.from_pretrained(
                self.hf_model_path,
@@ -964,10 +954,9 @@ class UnshardedVicuna(SharkLLMBase):
        return tokenizer

    def get_src_model(self):
-        kwargs = {
-            "torch_dtype": torch.float,
-            "use_auth_token": self.hf_auth_token,
-        }
+        kwargs = {"torch_dtype": torch.float}
+        if self.model_name == "llama2":
+            kwargs["use_auth_token"] = "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"
        vicuna_model = AutoModelForCausalLM.from_pretrained(
            self.hf_model_path,
            **kwargs,
@@ -1033,7 +1022,6 @@ class UnshardedVicuna(SharkLLMBase):
                self.precision,
                self.weight_group_size,
                self.model_name,
-                self.hf_auth_token,
            )

            print(f"[DEBUG] generating torchscript graph")
@@ -1198,7 +1186,6 @@ class UnshardedVicuna(SharkLLMBase):
                self.precision,
                self.weight_group_size,
                self.model_name,
-                self.hf_auth_token,
            )

            print(f"[DEBUG] generating torchscript graph")
@@ -1353,8 +1340,7 @@ class UnshardedVicuna(SharkLLMBase):
        ):
            if (self.device == "cuda" and self.precision == "fp16") or (
                self.device in ["cpu-sync", "cpu-task"]
-                and self.precision == "int8"
-                and self.download_vmfb
+                and self.precision == "int8" and self.download_vmfb
            ):
                download_public_file(
                    f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
@@ -1376,8 +1362,7 @@ class UnshardedVicuna(SharkLLMBase):
        ):
            if (self.device == "cuda" and self.precision == "fp16") or (
                self.device in ["cpu-sync", "cpu-task"]
-                and self.precision == "int8"
-                and self.download_vmfb
+                and self.precision == "int8" and self.download_vmfb
            ):
                download_public_file(
                    f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
@@ -1573,6 +1558,15 @@ class UnshardedVicuna(SharkLLMBase):

if __name__ == "__main__":
    args, unknown = parser.parse_known_args()
+    model_map = {
+        "llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
+        "llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
+        "codegen": "Salesforce/codegen25-7b-multi",
+        "vicuna1p3": "lmsys/vicuna-7b-v1.3",
+        "vicuna": "TheBloke/vicuna-7B-1.1-HF",
+    }
+
+    hf_model_id = model_map[args.model_to_run]

    vic = None
    if not args.sharded:
@@ -1598,8 +1592,8 @@ if __name__ == "__main__":
        )

        vic = UnshardedVicuna(
-            model_name=args.model_name,
-            hf_auth_token=args.hf_auth_token,
+            "vicuna",
+            hf_model_id,
            device=args.device,
            precision=args.precision,
            first_vicuna_mlir_path=first_vic_mlir_path,
@@ -1618,45 +1612,22 @@ if __name__ == "__main__":
    else:
        config_json = None
        vic = ShardedVicuna(
-            model_name=args.model_name,
+            "vicuna",
+            hf_model_id,
            device=args.device,
            precision=args.precision,
            config_json=config_json,
            weight_group_size=args.weight_group_size,
        )
-    if args.model_name == "vicuna":
-        system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
-    else:
-        system_message = """System: You are a helpful, respectful and honest assistant. Always answer "
-        as helpfully as possible, while being safe. Your answers should not
-        include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal
-        content. Please ensure that your responses are socially unbiased and positive
-        in nature. If a question does not make any sense, or is not factually coherent,
-        explain why instead of answering something not correct. If you don't know the
-        answer to a question, please don't share false information."""
+    system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
    prologue_prompt = "ASSISTANT:\n"

    from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model

    history = []
    set_vicuna_model(vic)

-    model_list = {
-        "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
-        "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
-        "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
-    }
    while True:
        # TODO: Add break condition from user input
        user_prompt = input("User: ")
-        history.append([user_prompt, ""])
-        history = list(
-            chat(
-                system_message,
-                history,
-                model=model_list[args.model_name],
-                device=args.device,
-                precision=args.precision,
-                cli=args.cli,
-            )
-        )[0]
+        history.append([user_prompt,""])
+        history = list(chat(system_message, history, model="vicuna=>TheBloke/vicuna-7B-1.1-HF", device=args.device, precision=args.precision, cli=args.cli))[0]

@@ -12,12 +12,11 @@ class FirstVicuna(torch.nn.Module):
        precision="fp32",
        weight_group_size=128,
        model_name="vicuna",
-        hf_auth_token: str = None,
    ):
        super().__init__()
        kwargs = {"torch_dtype": torch.float32}
-        if "llama2" in model_name:
-            kwargs["use_auth_token"] = hf_auth_token
+        if model_name == "llama2":
+            kwargs["use_auth_token"] = "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"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **kwargs
        )
@@ -27,13 +26,20 @@ class FirstVicuna(torch.nn.Module):
            quantize_model(
                get_model_impl(self.model).layers,
                dtype=torch.float32,
-                weight_quant_type="asym",
                weight_bit_width=weight_bit_width,
                weight_param_method="stats",
                weight_scale_precision="float",
+                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=weight_group_size,
                quantize_weight_zero_point=False,
+                input_bit_width=None,
+                input_scale_type="float",
+                input_param_method="stats",
+                input_quant_type="asym",
+                input_quant_granularity="per_tensor",
+                quantize_input_zero_point=False,
                seqlen=2048,
            )
            print("Weight quantization applied.")
@@ -55,12 +61,11 @@ class SecondVicuna(torch.nn.Module):
        precision="fp32",
        weight_group_size=128,
        model_name="vicuna",
-        hf_auth_token: str = None,
    ):
        super().__init__()
        kwargs = {"torch_dtype": torch.float32}
-        if "llama2" in model_name:
-            kwargs["use_auth_token"] = hf_auth_token
+        if model_name == "llama2":
+            kwargs["use_auth_token"] = "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"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, low_cpu_mem_usage=True, **kwargs
        )
@@ -70,13 +75,20 @@ class SecondVicuna(torch.nn.Module):
            quantize_model(
                get_model_impl(self.model).layers,
                dtype=torch.float32,
-                weight_quant_type="asym",
                weight_bit_width=weight_bit_width,
                weight_param_method="stats",
                weight_scale_precision="float",
+                weight_quant_type="asym",
                weight_quant_granularity="per_group",
                weight_group_size=weight_group_size,
                quantize_weight_zero_point=False,
+                input_bit_width=None,
+                input_scale_type="float",
+                input_param_method="stats",
+                input_quant_type="asym",
+                input_quant_granularity="per_tensor",
+                quantize_input_zero_point=False,
                seqlen=2048,
            )
            print("Weight quantization applied.")

@@ -400,13 +400,6 @@ p.add_argument(
    help="Load and unload models for low VRAM.",
)

-p.add_argument(
-    "--hf_auth_token",
-    type=str,
-    default=None,
-    help="Specify your own huggingface authentication tokens for models like Llama2.",
-)
-
##############################################################################
# IREE - Vulkan supported flags
##############################################################################

@@ -117,12 +117,16 @@ body {
    padding: 0 var(--size-4) !important;
}

+.container {
+    background-color: black !important;
+    padding-top: var(--size-5) !important;
+}
+
#ui_title {
    padding: var(--size-2) 0 0 var(--size-1);
}

#top_logo {
    color: transparent;
    background-color: transparent;
    border-radius: 0 !important;
    border: 0;

@@ -21,132 +21,129 @@ def user(message, history):


sharkModel = 0
sharded_model = 0
h2ogpt_model = 0

past_key_values = None

+model_map = {
+    "codegen": "Salesforce/codegen25-7b-multi",
+    "vicuna1p3": "lmsys/vicuna-7b-v1.3",
+    "vicuna": "TheBloke/vicuna-7B-1.1-HF",
+    "StableLM": "stabilityai/stablelm-tuned-alpha-3b",
+}
+
+# NOTE: Each `model_name` should have its own start message
-start_message = """
-SHARK DocuChat
-Chat with an AI, contextualized with provided files.
-"""
+start_message = {
+    "StableLM": (
+        "<|SYSTEM|># StableLM Tuned (Alpha version)"
+        "\n- StableLM is a helpful and harmless open-source AI language model "
+        "developed by StabilityAI."
+        "\n- StableLM is excited to be able to help the user, but will refuse "
+        "to do anything that could be considered harmful to the user."
+        "\n- StableLM is more than just an information source, StableLM is also "
+        "able to write poetry, short stories, and make jokes."
+        "\n- StableLM will refuse to participate in anything that "
+        "could harm a human."
+    ),
+    "vicuna": (
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's "
+        "questions.\n"
+    ),
+    "vicuna1p3": (
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's "
+        "questions.\n"
+    ),
+    "codegen": "",
+}


-def create_prompt(history):
-    system_message = start_message
+def create_prompt(model_name, history):
+    system_message = start_message[model_name]

-    conversation = "".join(["".join([item[0], item[1]]) for item in history])
+    if model_name in ["StableLM", "vicuna", "vicuna1p3"]:
+        conversation = "".join(
+            [
+                "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
+                for item in history
+            ]
+        )
+    else:
+        conversation = "".join(
+            ["".join([item[0], item[1]]) for item in history]
+        )

    msg = system_message + conversation
    msg = msg.strip()
    return msg


-def chat(curr_system_message, history, device, precision):
+def chat(curr_system_message, history, model, device, precision):
    args.run_docuchat_web = True
    global sharded_model
    global past_key_values
    global h2ogpt_model
-    global h2ogpt_tokenizer
-    global model_state
-    global langchain
-    global userpath_selector

-    if h2ogpt_model == 0:
-        if "cuda" in device:
-            device = "cuda"
-        elif "sync" in device:
-            device = "cpu"
-        elif "task" in device:
-            device = "cpu"
-        elif "vulkan" in device:
-            device = "vulkan"
-        else:
-            print("unrecognized device")
+    model_name, model_path = list(map(str.strip, model.split("=>")))
+    print(f"In chat for {model_name}")

    args.device = device
    args.precision = precision
+    # if h2ogpt_model == 0:
+    #     if "cuda" in device:
+    #         device = "cuda"
+    #     elif "sync" in device:
+    #         device = "cpu-sync"
+    #     elif "task" in device:
+    #         device = "cpu-task"
+    #     elif "vulkan" in device:
+    #         device = "vulkan"
+    #     else:
+    #         print("unrecognized device")

-        from apps.language_models.langchain.gen import Langchain
+    # max_toks = 128 if model_name == "codegen" else 512
+    # h2ogpt_model = UnshardedVicuna(
+    #     model_name,
+    #     hf_model_path=model_path,
+    #     device=device,
+    #     precision=precision,
+    #     max_num_tokens=max_toks,
+    # )
+    # prompt = create_prompt(model_name, history)
+    # print("prompt = ", prompt)

-        langchain = Langchain(device, precision)
-        h2ogpt_model, h2ogpt_tokenizer, _ = langchain.get_model(
-            load_8bit=True
-            if device == "cuda"
-            else False,  # load model in 4bit if device is cuda to save memory
-            load_gptq="",
-            use_safetensors=False,
-            infer_devices=True,
-            device=device,
-            base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
-            inference_server="",
-            tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
-            lora_weights="",
-            gpu_id=0,
-            reward_type=None,
-            local_files_only=False,
-            resume_download=True,
-            use_auth_token=False,
-            trust_remote_code=True,
-            offload_folder=None,
-            compile_model=False,
-            verbose=False,
-        )
-        model_state = dict(
-            model=h2ogpt_model,
-            tokenizer=h2ogpt_tokenizer,
-            device=device,
-            base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
-            tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
-            lora_weights="",
-            inference_server="",
-            prompt_type=None,
-            prompt_dict=None,
-        )

-    prompt = create_prompt(history)
-    output = langchain.evaluate(
-        model_state=model_state,
-        my_db_state=None,
-        instruction=prompt,
-        iinput="",
-        context="",
-        stream_output=True,
-        prompt_type="prompt_answer",
-        prompt_dict={
-            "promptA": "",
-            "promptB": "",
-            "PreInstruct": "<|prompt|>",
-            "PreInput": None,
-            "PreResponse": "<|answer|>",
-            "terminate_response": [
-                "<|prompt|>",
-                "<|answer|>",
-                "<|endoftext|>",
-            ],
-            "chat_sep": "<|endoftext|>",
-            "chat_turn_sep": "<|endoftext|>",
-            "humanstr": "<|prompt|>",
-            "botstr": "<|answer|>",
-            "generates_leading_space": False,
-        },
-        temperature=0.1,
-        top_p=0.75,
-        top_k=40,
-        num_beams=1,
-        max_new_tokens=256,
-        min_new_tokens=0,
-        early_stopping=False,
-        max_time=180,
-        repetition_penalty=1.07,
-        num_return_sequences=1,
-        do_sample=False,
-        chat=True,
-        instruction_nochat=prompt,
-        iinput_nochat="",
-        langchain_mode="UserData",
-        langchain_action=LangChainAction.QUERY.value,
-        top_k_docs=3,
-        chunk=True,
-        chunk_size=512,
-        document_choice=[DocumentChoices.All_Relevant.name],
+    # for partial_text in h2ogpt_model.generate(prompt):
+    #     history[-1][1] = partial_text
+    #     yield history
+    output = gen.evaluate(
+        None,  # model_state
+        None,  # my_db_state
+        None,  # instruction
+        None,  # iinput
+        history,  # context
+        False,  # stream_output
+        None,  # prompt_type
+        None,  # prompt_dict
+        None,  # temperature
+        None,  # top_p
+        None,  # top_k
+        None,  # num_beams
+        None,  # max_new_tokens
+        None,  # min_new_tokens
+        None,  # early_stopping
+        None,  # max_time
+        None,  # repetition_penalty
+        None,  # num_return_sequences
+        False,  # do_sample
+        False,  # chat
+        None,  # instruction_nochat
+        curr_system_message,  # iinput_nochat
+        "Disabled",  # langchain_mode
+        LangChainAction.QUERY.value,  # langchain_action
+        3,  # top_k_docs
+        True,  # chunk
+        512,  # chunk_size
+        [DocumentChoices.All_Relevant.name],  # document_choice
        concurrency_count=1,
        memory_restriction_level=2,
        raise_generate_gpu_exceptions=False,
@@ -157,13 +154,9 @@ def chat(curr_system_message, history, device, precision):
        db_type="chroma",
        n_jobs=-1,
        first_para=False,
        max_max_time=60 * 2,
-        model_state0=model_state,
-        model_lock=True,
-        user_path=userpath_selector.value,
    )
    for partial_text in output:
-        history[-1][1] = partial_text["response"]
+        history[-1][1] = partial_text
        yield history

    return history
@@ -171,6 +164,14 @@ def chat(curr_system_message, history, device, precision):

with gr.Blocks(title="H2OGPT") as h2ogpt_web:
    with gr.Row():
+        model_choices = list(
+            map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items())
+        )
+        model = gr.Dropdown(
+            label="Select Model",
+            value=model_choices[0],
+            choices=model_choices,
+        )
        supported_devices = available_devices
        enabled = len(supported_devices) > 0
        # show cpu-task device first in list for chatbot
@@ -196,14 +197,6 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
            ],
            visible=True,
        )
-        userpath_selector = gr.Textbox(
-            label="Document Directory",
-            value=str(
-                os.path.abspath("apps/language_models/langchain/user_path/")
-            ),
-            interactive=True,
-            container=True,
-        )
        chatbot = gr.Chatbot(height=500)
    with gr.Row():
        with gr.Column():
@@ -227,7 +220,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        fn=chat,
-        inputs=[system_msg, chatbot, device, precision],
+        inputs=[system_msg, chatbot, model, device, precision],
        outputs=[chatbot],
        queue=True,
    )
@@ -235,7 +228,7 @@ with gr.Blocks(title="H2OGPT") as h2ogpt_web:
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        fn=chat,
-        inputs=[system_msg, chatbot, device, precision],
+        inputs=[system_msg, chatbot, model, device, precision],
        outputs=[chatbot],
        queue=True,
    )

@@ -31,16 +31,7 @@ model_map = {

# NOTE: Each `model_name` should have its own start message
start_message = {
-    "llama2_7b": (
-        "System: You are a helpful, respectful and honest assistant. Always answer "
-        "as helpfully as possible, while being safe. Your answers should not "
-        "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
-        "content. Please ensure that your responses are socially unbiased and positive "
-        "in nature. If a question does not make any sense, or is not factually coherent, "
-        "explain why instead of answering something not correct. If you don't know the "
-        "answer to a question, please don't share false information."
-    ),
-    "llama2_70b": (
+    "llama2": (
        "System: You are a helpful, respectful and honest assistant. Always answer "
        "as helpfully as possible, while being safe. Your answers should not "
        "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
@@ -77,13 +68,7 @@ start_message = {
def create_prompt(model_name, history):
    system_message = start_message[model_name]

-    if model_name in [
-        "StableLM",
-        "vicuna",
-        "vicuna1p3",
-        "llama2_7b",
-        "llama2_70b",
-    ]:
+    if model_name in ["StableLM", "vicuna", "vicuna1p3", "llama2"]:
        conversation = "".join(
            [
                "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
@@ -112,17 +97,10 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
    global vicuna_model
    model_name, model_path = list(map(str.strip, model.split("=>")))

-    if model_name in [
-        "vicuna",
-        "vicuna1p3",
-        "codegen",
-        "llama2_7b",
-        "llama2_70b",
-    ]:
+    if model_name in ["vicuna", "vicuna1p3", "codegen", "llama2"]:
        from apps.language_models.scripts.vicuna import (
            UnshardedVicuna,
        )
        from apps.stable_diffusion.src import args

        if vicuna_model == 0:
            if "cuda" in device:
@@ -140,7 +118,6 @@ def chat(curr_system_message, history, model, device, precision, cli=True):
                vicuna_model = UnshardedVicuna(
                    model_name,
                    hf_model_path=model_path,
-                    hf_auth_token=args.hf_auth_token,
                    device=device,
                    precision=precision,
                    max_num_tokens=max_toks,

@@ -395,7 +395,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
                    value=args.scheduler,
                    choices=scheduler_list,
                )
-            with gr.Column():
+            with gr.Group():
                save_metadata_to_png = gr.Checkbox(
                    label="Save prompt information to PNG",
                    value=args.write_metadata_to_png,

@@ -39,5 +39,5 @@ joblib # for langchain
pefile
pyinstaller

-# vicuna quantization
-brevitas @ git+https://github.com/Xilinx/brevitas.git@dev
+# low precision vicuna
+brevitas @ git+https://github.com/Xilinx/brevitas.git@llm

@@ -159,3 +159,5 @@ if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
  echo "${Green}Before running examples activate venv with:"
  echo "  ${Green}source $VENV_DIR/bin/activate"
fi
+
+$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@llm