mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
1 Commits
20230818.8
...
dlrm-train
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16daba99fe |
8
.github/workflows/nightly.yml
vendored
8
.github/workflows/nightly.yml
vendored
@@ -51,11 +51,11 @@ jobs:
|
||||
run: |
|
||||
./setup_venv.ps1
|
||||
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
|
||||
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
python process_skipfiles.py
|
||||
pyinstaller .\apps\stable_diffusion\shark_sd.spec
|
||||
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
|
||||
|
||||
- name: Upload Release Assets
|
||||
id: upload-release-assets
|
||||
@@ -104,7 +104,7 @@ jobs:
|
||||
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install flake8 pytest toml
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
|
||||
- name: Lint with flake8
|
||||
run: |
|
||||
# stop the build if there are Python syntax errors or undefined names
|
||||
@@ -144,7 +144,7 @@ jobs:
|
||||
source shark.venv/bin/activate
|
||||
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
|
||||
SHARK_PACKAGE_VERSION=${package_version} \
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
# Install the built wheel
|
||||
pip install ./wheelhouse/nodai*
|
||||
# Validate the Models
|
||||
|
||||
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -1,4 +1,4 @@
|
||||
[submodule "inference/thirdparty/shark-runtime"]
|
||||
path = inference/thirdparty/shark-runtime
|
||||
url =https://github.com/nod-ai/SRT.git
|
||||
url =https://github.com/nod-ai/SHARK-Runtime.git
|
||||
branch = shark-06032022
|
||||
|
||||
@@ -170,7 +170,7 @@ python -m pip install --upgrade pip
|
||||
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
|
||||
|
||||
```shell
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
```
|
||||
|
||||
### Run shark tank model tests.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
"""Load question answering chains."""
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -10,34 +11,23 @@ from typing import (
|
||||
Union,
|
||||
Protocol,
|
||||
)
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from abc import ABC, abstractmethod
|
||||
import langchain
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.question_answering import stuff_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
from abc import ABC, abstractmethod
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
from pydantic import Extra, Field, root_validator, validator
|
||||
|
||||
|
||||
def _get_verbosity() -> bool:
|
||||
return langchain.verbose
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
@@ -58,413 +48,6 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
|
||||
return prompt.format(**document_info)
|
||||
|
||||
|
||||
class Chain(Serializable, ABC):
|
||||
"""Base interface that all chains should implement."""
|
||||
|
||||
memory: Optional[BaseMemory] = None
|
||||
callbacks: Callbacks = Field(default=None, exclude=True)
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
verbose: bool = Field(
|
||||
default_factory=_get_verbosity
|
||||
) # Whether to print the response text
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
raise NotImplementedError("Saving not supported for this chain type.")
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
"""Raise deprecation warning if callback_manager is used."""
|
||||
if values.get("callback_manager") is not None:
|
||||
warnings.warn(
|
||||
"callback_manager is deprecated. Please use callbacks instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
values["callbacks"] = values.pop("callback_manager", None)
|
||||
return values
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
"""If verbose is None, set it.
|
||||
|
||||
This allows users to pass in None as verbose to access the global setting.
|
||||
"""
|
||||
if verbose is None:
|
||||
return _get_verbosity()
|
||||
else:
|
||||
return verbose
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Input keys this chain expects."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output keys this chain expects."""
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
missing_keys = set(self.output_keys).difference(outputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some output keys: {missing_keys}")
|
||||
|
||||
@abstractmethod
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
Args:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
response. If True, only new keys generated by this chain will be
|
||||
returned. If False, both input keys and new keys generated by this
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
"""
|
||||
input_docs = inputs["input_documents"]
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
if missing_keys:
|
||||
raise ValueError(f"Missing some input keys: {missing_keys}")
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
)
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
inputs,
|
||||
)
|
||||
|
||||
if "is_first" in inputs.keys() and not inputs["is_first"]:
|
||||
run_manager_ = run_manager
|
||||
input_list = [inputs]
|
||||
stop = None
|
||||
prompts = []
|
||||
for inputs in input_list:
|
||||
selected_inputs = {
|
||||
k: inputs[k] for k in self.prompt.input_variables
|
||||
}
|
||||
prompt = self.prompt.format_prompt(**selected_inputs)
|
||||
_colored_text = get_colored_text(prompt.to_string(), "green")
|
||||
_text = "Prompt after formatting:\n" + _colored_text
|
||||
if run_manager_:
|
||||
run_manager_.on_text(_text, end="\n", verbose=self.verbose)
|
||||
if "stop" in inputs and inputs["stop"] != stop:
|
||||
raise ValueError(
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
|
||||
prompt_strings = [p.to_string() for p in prompts]
|
||||
prompts = prompt_strings
|
||||
callbacks = run_manager_.get_child() if run_manager_ else None
|
||||
tags = None
|
||||
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
# If string is passed in directly no errors will be raised but outputs will
|
||||
# not make sense.
|
||||
if not isinstance(prompts, list):
|
||||
raise ValueError(
|
||||
"Argument 'prompts' is expected to be of type List[str], received"
|
||||
f" argument of type {type(prompts)}."
|
||||
)
|
||||
params = self.llm.dict()
|
||||
params["stop"] = stop
|
||||
options = {"stop": stop}
|
||||
disregard_cache = self.llm.cache is not None and not self.llm.cache
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks,
|
||||
self.llm.callbacks,
|
||||
self.llm.verbose,
|
||||
tags,
|
||||
self.llm.tags,
|
||||
)
|
||||
if langchain.llm_cache is None or disregard_cache:
|
||||
# This happens when langchain.cache is None, but self.cache is True
|
||||
if self.llm.cache is not None and self.cache:
|
||||
raise ValueError(
|
||||
"Asked to cache, but no cache found at `langchain.cache`."
|
||||
)
|
||||
run_manager_ = callback_manager.on_llm_start(
|
||||
dumpd(self),
|
||||
prompts,
|
||||
invocation_params=params,
|
||||
options=options,
|
||||
)
|
||||
|
||||
generations = []
|
||||
for prompt in prompts:
|
||||
inputs_ = prompt
|
||||
num_workers = None
|
||||
batch_size = None
|
||||
|
||||
if num_workers is None:
|
||||
if self.llm.pipeline._num_workers is None:
|
||||
num_workers = 0
|
||||
else:
|
||||
num_workers = self.llm.pipeline._num_workers
|
||||
if batch_size is None:
|
||||
if self.llm.pipeline._batch_size is None:
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = self.llm.pipeline._batch_size
|
||||
|
||||
preprocess_params = {}
|
||||
generate_kwargs = {}
|
||||
preprocess_params.update(generate_kwargs)
|
||||
forward_params = generate_kwargs
|
||||
postprocess_params = {}
|
||||
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
|
||||
preprocess_params = {
|
||||
**self.llm.pipeline._preprocess_params,
|
||||
**preprocess_params,
|
||||
}
|
||||
forward_params = {
|
||||
**self.llm.pipeline._forward_params,
|
||||
**forward_params,
|
||||
}
|
||||
postprocess_params = {
|
||||
**self.llm.pipeline._postprocess_params,
|
||||
**postprocess_params,
|
||||
}
|
||||
|
||||
self.llm.pipeline.call_count += 1
|
||||
if (
|
||||
self.llm.pipeline.call_count > 10
|
||||
and self.llm.pipeline.framework == "pt"
|
||||
and self.llm.pipeline.device.type == "cuda"
|
||||
):
|
||||
warnings.warn(
|
||||
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
|
||||
" dataset",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
model_inputs = self.llm.pipeline.preprocess(
|
||||
inputs_, **preprocess_params
|
||||
)
|
||||
model_outputs = self.llm.pipeline.forward(
|
||||
model_inputs, **forward_params
|
||||
)
|
||||
model_outputs["process"] = False
|
||||
return model_outputs
|
||||
output = LLMResult(generations=generations)
|
||||
run_manager_.on_llm_end(output)
|
||||
if run_manager_:
|
||||
output.run = RunInfo(run_id=run_manager_.run_id)
|
||||
response = output
|
||||
|
||||
outputs = [
|
||||
# Get the text of the top generated string.
|
||||
{self.output_key: generation[0].text}
|
||||
for generation in response.generations
|
||||
][0]
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
else:
|
||||
_run_manager = (
|
||||
run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
)
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {
|
||||
k: v for k, v in inputs.items() if k != self.input_key
|
||||
}
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in other_keys.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
)
|
||||
inputs["is_first"] = False
|
||||
inputs["input_documents"] = input_docs
|
||||
|
||||
# Call predict on the LLM.
|
||||
output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
|
||||
if "process" in output.keys() and not output["process"]:
|
||||
return output
|
||||
output = output[self.llm_chain.output_key]
|
||||
extra_return_dict = {}
|
||||
extra_return_dict[self.output_key] = output
|
||||
outputs = extra_return_dict
|
||||
run_manager.on_chain_end(outputs)
|
||||
final_outputs: Dict[str, Any] = self.prep_outputs(
|
||||
inputs, outputs, return_only_outputs
|
||||
)
|
||||
if include_run_info:
|
||||
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
|
||||
return final_outputs
|
||||
|
||||
def prep_outputs(
|
||||
self,
|
||||
inputs: Dict[str, str],
|
||||
outputs: Dict[str, str],
|
||||
return_only_outputs: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep outputs."""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
return {**inputs, **outputs}
|
||||
|
||||
def prep_inputs(
|
||||
self, inputs: Union[Dict[str, Any], Any]
|
||||
) -> Dict[str, str]:
|
||||
"""Validate and prep inputs."""
|
||||
if not isinstance(inputs, dict):
|
||||
_input_keys = set(self.input_keys)
|
||||
if self.memory is not None:
|
||||
# If there are multiple input keys, but some get set by memory so that
|
||||
# only one is not set, we can still figure out which key it is.
|
||||
_input_keys = _input_keys.difference(
|
||||
self.memory.memory_variables
|
||||
)
|
||||
if len(_input_keys) != 1:
|
||||
raise ValueError(
|
||||
f"A single string input was passed in, but this chain expects "
|
||||
f"multiple inputs ({_input_keys}). When a chain expects "
|
||||
f"multiple inputs, please call it by passing in a dictionary, "
|
||||
"eg `chain({'foo': 1, 'bar': 2})`"
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
|
||||
def apply(
|
||||
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Call the chain on all inputs in the list."""
|
||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||
|
||||
def run(
|
||||
self,
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
if len(self.output_keys) != 1:
|
||||
raise ValueError(
|
||||
f"`run` not supported when there is not exactly "
|
||||
f"one output key. Got {self.output_keys}."
|
||||
)
|
||||
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError(
|
||||
"`run` supports only one positional argument."
|
||||
)
|
||||
return self(args[0], callbacks=callbacks, tags=tags)[
|
||||
self.output_keys[0]
|
||||
]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs, callbacks=callbacks, tags=tags)[
|
||||
self.output_keys[0]
|
||||
]
|
||||
|
||||
if not kwargs and not args:
|
||||
raise ValueError(
|
||||
"`run` supported with either positional arguments or keyword arguments,"
|
||||
" but none were provided."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||
)
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of chain."""
|
||||
if self.memory is not None:
|
||||
raise ValueError("Saving of memory is not yet supported.")
|
||||
_dict = super().dict()
|
||||
_dict["_type"] = self._chain_type
|
||||
return _dict
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Save the chain.
|
||||
|
||||
Args:
|
||||
file_path: Path to file to save the chain to.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
chain.save(file_path="path/chain.yaml")
|
||||
"""
|
||||
# Convert file to Path object.
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
chain_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(chain_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(chain_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""Base interface for chains combining documents."""
|
||||
|
||||
@@ -496,6 +79,12 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def combine_docs(
|
||||
self, docs: List[Document], **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Combine documents into a single string."""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
inputs: Dict[str, List[Document]],
|
||||
@@ -507,49 +96,13 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
docs = inputs[self.input_key]
|
||||
# Other keys are assumed to be needed for LLM prediction
|
||||
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
|
||||
doc_strings = [
|
||||
format_document(doc, self.document_prompt) for doc in docs
|
||||
]
|
||||
# Join the documents together to put them in the prompt.
|
||||
inputs = {
|
||||
k: v
|
||||
for k, v in other_keys.items()
|
||||
if k in self.llm_chain.prompt.input_variables
|
||||
}
|
||||
inputs[self.document_variable_name] = self.document_separator.join(
|
||||
doc_strings
|
||||
output, extra_return_dict = self.combine_docs(
|
||||
docs, callbacks=_run_manager.get_child(), **other_keys
|
||||
)
|
||||
|
||||
# Call predict on the LLM.
|
||||
output, extra_return_dict = (
|
||||
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
|
||||
self.llm_chain.output_key
|
||||
],
|
||||
{},
|
||||
)
|
||||
|
||||
extra_return_dict[self.output_key] = output
|
||||
return extra_return_dict
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Generation(Serializable):
|
||||
"""Output of a single generation."""
|
||||
|
||||
text: str
|
||||
"""Generated text output."""
|
||||
|
||||
generation_info: Optional[Dict[str, Any]] = None
|
||||
"""Raw generation info response from the provider"""
|
||||
"""May include things like reason for finishing (e.g. in OpenAI)"""
|
||||
# TODO: add log probs
|
||||
|
||||
|
||||
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
|
||||
|
||||
|
||||
class LLMChain(Chain):
|
||||
"""Chain to run queries against LLMs.
|
||||
|
||||
@@ -600,13 +153,21 @@ class LLMChain(Chain):
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, str]:
|
||||
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
|
||||
response = self.llm.generate_prompt(
|
||||
response = self.generate([inputs], run_manager=run_manager)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
input_list: List[Dict[str, Any]],
|
||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||
) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
|
||||
return self.llm.generate_prompt(
|
||||
prompts,
|
||||
stop,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
)
|
||||
return self.create_outputs(response)[0]
|
||||
|
||||
def prep_prompts(
|
||||
self,
|
||||
@@ -662,6 +223,23 @@ class LLMChain(Chain):
|
||||
for generation in response.generations
|
||||
]
|
||||
|
||||
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Callbacks to pass to LLMChain
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
Completion from LLM.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return self(kwargs, callbacks=callbacks)[self.output_key]
|
||||
|
||||
def predict_and_parse(
|
||||
self, callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Union[str, List[str], Dict[str, Any]]:
|
||||
@@ -772,6 +350,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
|
||||
prompt = self.llm_chain.prompt.format(**inputs)
|
||||
return self.llm_chain.llm.get_num_tokens(prompt)
|
||||
|
||||
def combine_docs(
|
||||
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
||||
) -> Tuple[str, dict]:
|
||||
"""Stuff all documents into one prompt and pass to LLM."""
|
||||
inputs = self._get_inputs(docs, **kwargs)
|
||||
# Call predict on the LLM.
|
||||
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
|
||||
|
||||
@property
|
||||
def _chain_type(self) -> str:
|
||||
return "stuff_documents_chain"
|
||||
|
||||
@@ -1129,7 +1129,7 @@ class Langchain:
|
||||
max_time=max_time,
|
||||
num_return_sequences=num_return_sequences,
|
||||
)
|
||||
out = run_qa_db(
|
||||
for r in run_qa_db(
|
||||
query=instruction,
|
||||
iinput=iinput,
|
||||
context=context,
|
||||
@@ -1170,8 +1170,689 @@ class Langchain:
|
||||
auto_reduce_chunks=auto_reduce_chunks,
|
||||
max_chunks=max_chunks,
|
||||
device=self.device,
|
||||
):
|
||||
(
|
||||
outr,
|
||||
extra,
|
||||
) = r # doesn't accumulate, new answer every yield, so only save that full answer
|
||||
yield dict(response=outr, sources=extra)
|
||||
if save_dir:
|
||||
extra_dict = gen_hyper_langchain.copy()
|
||||
extra_dict.update(
|
||||
prompt_type=prompt_type,
|
||||
inference_server=inference_server,
|
||||
langchain_mode=langchain_mode,
|
||||
langchain_action=langchain_action,
|
||||
document_choice=document_choice,
|
||||
num_prompt_tokens=num_prompt_tokens,
|
||||
instruction=instruction,
|
||||
iinput=iinput,
|
||||
context=context,
|
||||
)
|
||||
save_generate_output(
|
||||
prompt=prompt,
|
||||
output=outr,
|
||||
base_model=base_model,
|
||||
save_dir=save_dir,
|
||||
where_from="run_qa_db",
|
||||
extra_dict=extra_dict,
|
||||
)
|
||||
if verbose:
|
||||
print(
|
||||
"Post-Generate Langchain: %s decoded_output: %s"
|
||||
% (str(datetime.now()), len(outr) if outr else -1),
|
||||
flush=True,
|
||||
)
|
||||
if outr or base_model in non_hf_types:
|
||||
# if got no response (e.g. not showing sources and got no sources,
|
||||
# so nothing to give to LLM), then slip through and ask LLM
|
||||
# Or if llama/gptj, then just return since they had no response and can't go down below code path
|
||||
# clear before return, since .then() never done if from API
|
||||
clear_torch_cache()
|
||||
return
|
||||
|
||||
if inference_server.startswith(
|
||||
"openai"
|
||||
) or inference_server.startswith("http"):
|
||||
if inference_server.startswith("openai"):
|
||||
import openai
|
||||
|
||||
where_from = "openai_client"
|
||||
|
||||
openai.api_key = os.getenv("OPENAI_API_KEY")
|
||||
stop_sequences = list(
|
||||
set(prompter.terminate_response + [prompter.PreResponse])
|
||||
)
|
||||
stop_sequences = [x for x in stop_sequences if x]
|
||||
# OpenAI will complain if ask for too many new tokens, takes it as min in some sense, wrongly so.
|
||||
max_new_tokens_openai = min(
|
||||
max_new_tokens, model_max_length - num_prompt_tokens
|
||||
)
|
||||
gen_server_kwargs = dict(
|
||||
temperature=temperature if do_sample else 0,
|
||||
max_tokens=max_new_tokens_openai,
|
||||
top_p=top_p if do_sample else 1,
|
||||
frequency_penalty=0,
|
||||
n=num_return_sequences,
|
||||
presence_penalty=1.07
|
||||
- repetition_penalty
|
||||
+ 0.6, # so good default
|
||||
)
|
||||
if inference_server == "openai":
|
||||
response = openai.Completion.create(
|
||||
model=base_model,
|
||||
prompt=prompt,
|
||||
**gen_server_kwargs,
|
||||
stop=stop_sequences,
|
||||
stream=stream_output,
|
||||
)
|
||||
if not stream_output:
|
||||
text = response["choices"][0]["text"]
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
prompt + text,
|
||||
prompt=prompt,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources="",
|
||||
)
|
||||
else:
|
||||
collected_events = []
|
||||
text = ""
|
||||
for event in response:
|
||||
collected_events.append(
|
||||
event
|
||||
) # save the event response
|
||||
event_text = event["choices"][0][
|
||||
"text"
|
||||
] # extract the text
|
||||
text += event_text # append the text
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
prompt + text,
|
||||
prompt=prompt,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources="",
|
||||
)
|
||||
elif inference_server == "openai_chat":
|
||||
response = openai.ChatCompletion.create(
|
||||
model=base_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
],
|
||||
stream=stream_output,
|
||||
**gen_server_kwargs,
|
||||
)
|
||||
if not stream_output:
|
||||
text = response["choices"][0]["message"]["content"]
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
prompt + text,
|
||||
prompt=prompt,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources="",
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
for chunk in response:
|
||||
delta = chunk["choices"][0]["delta"]
|
||||
if "content" in delta:
|
||||
text += delta["content"]
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
prompt + text,
|
||||
prompt=prompt,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources="",
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No such OpenAI mode: %s" % inference_server
|
||||
)
|
||||
elif inference_server.startswith("http"):
|
||||
inference_server, headers = get_hf_server(inference_server)
|
||||
from gradio_utils.grclient import GradioClient
|
||||
from text_generation import Client as HFClient
|
||||
|
||||
if isinstance(model, GradioClient):
|
||||
gr_client = model
|
||||
hf_client = None
|
||||
elif isinstance(model, HFClient):
|
||||
gr_client = None
|
||||
hf_client = model
|
||||
else:
|
||||
(
|
||||
inference_server,
|
||||
gr_client,
|
||||
hf_client,
|
||||
) = self.get_client_from_inference_server(
|
||||
inference_server, base_model=base_model
|
||||
)
|
||||
|
||||
# quick sanity check to avoid long timeouts, just see if can reach server
|
||||
requests.get(
|
||||
inference_server,
|
||||
timeout=int(os.getenv("REQUEST_TIMEOUT_FAST", "10")),
|
||||
)
|
||||
|
||||
if gr_client is not None:
|
||||
# Note: h2oGPT gradio server could handle input token size issues for prompt,
|
||||
# but best to handle here so send less data to server
|
||||
|
||||
chat_client = False
|
||||
where_from = "gr_client"
|
||||
client_langchain_mode = "Disabled"
|
||||
client_langchain_action = LangChainAction.QUERY.value
|
||||
gen_server_kwargs = dict(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
num_beams=num_beams,
|
||||
max_new_tokens=max_new_tokens,
|
||||
min_new_tokens=min_new_tokens,
|
||||
early_stopping=early_stopping,
|
||||
max_time=max_time,
|
||||
repetition_penalty=repetition_penalty,
|
||||
num_return_sequences=num_return_sequences,
|
||||
do_sample=do_sample,
|
||||
chat=chat_client,
|
||||
)
|
||||
# account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection
|
||||
if prompt_type in [
|
||||
None,
|
||||
"",
|
||||
PromptType.plain.name,
|
||||
PromptType.plain.value,
|
||||
str(PromptType.plain.value),
|
||||
]:
|
||||
# if our prompt is plain, assume either correct or gradio server knows different prompt type,
|
||||
# so pass empty prompt_Type
|
||||
gr_prompt_type = ""
|
||||
gr_prompt_dict = ""
|
||||
gr_prompt = prompt # already prepared prompt
|
||||
gr_context = ""
|
||||
gr_iinput = ""
|
||||
else:
|
||||
# if already have prompt_type that is not plain, None, or '', then already applied some prompting
|
||||
# But assume server can handle prompting, and need to avoid double-up.
|
||||
# Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle
|
||||
# So avoid "prompt" and let gradio server reconstruct from prompt_type we passed
|
||||
# Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed,
|
||||
# because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter
|
||||
# since those won't appear
|
||||
gr_context = context
|
||||
gr_prompt = instruction
|
||||
gr_iinput = iinput
|
||||
gr_prompt_type = prompt_type
|
||||
gr_prompt_dict = prompt_dict
|
||||
client_kwargs = dict(
|
||||
instruction=gr_prompt
|
||||
if chat_client
|
||||
else "", # only for chat=True
|
||||
iinput=gr_iinput, # only for chat=True
|
||||
context=gr_context,
|
||||
# streaming output is supported, loops over and outputs each generation in streaming mode
|
||||
# but leave stream_output=False for simple input/output mode
|
||||
stream_output=stream_output,
|
||||
**gen_server_kwargs,
|
||||
prompt_type=gr_prompt_type,
|
||||
prompt_dict=gr_prompt_dict,
|
||||
instruction_nochat=gr_prompt
|
||||
if not chat_client
|
||||
else "",
|
||||
iinput_nochat=gr_iinput, # only for chat=False
|
||||
langchain_mode=client_langchain_mode,
|
||||
langchain_action=client_langchain_action,
|
||||
top_k_docs=top_k_docs,
|
||||
chunk=chunk,
|
||||
chunk_size=chunk_size,
|
||||
document_choice=[DocumentChoices.All_Relevant.name],
|
||||
)
|
||||
api_name = "/submit_nochat_api" # NOTE: like submit_nochat but stable API for string dict passing
|
||||
if not stream_output:
|
||||
res = gr_client.predict(
|
||||
str(dict(client_kwargs)), api_name=api_name
|
||||
)
|
||||
res_dict = ast.literal_eval(res)
|
||||
text = res_dict["response"]
|
||||
sources = res_dict["sources"]
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
prompt + text,
|
||||
prompt=prompt,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources=sources,
|
||||
)
|
||||
else:
|
||||
job = gr_client.submit(
|
||||
str(dict(client_kwargs)), api_name=api_name
|
||||
)
|
||||
text = ""
|
||||
sources = ""
|
||||
res_dict = dict(response=text, sources=sources)
|
||||
while not job.done():
|
||||
outputs_list = job.communicator.job.outputs
|
||||
if outputs_list:
|
||||
res = job.communicator.job.outputs[-1]
|
||||
res_dict = ast.literal_eval(res)
|
||||
text = res_dict["response"]
|
||||
sources = res_dict["sources"]
|
||||
if gr_prompt_type == "plain":
|
||||
# then gradio server passes back full prompt + text
|
||||
prompt_and_text = text
|
||||
else:
|
||||
prompt_and_text = prompt + text
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
prompt_and_text,
|
||||
prompt=prompt,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources=sources,
|
||||
)
|
||||
time.sleep(0.01)
|
||||
# ensure get last output to avoid race
|
||||
res_all = job.outputs()
|
||||
if len(res_all) > 0:
|
||||
res = res_all[-1]
|
||||
res_dict = ast.literal_eval(res)
|
||||
text = res_dict["response"]
|
||||
sources = res_dict["sources"]
|
||||
else:
|
||||
# go with old text if last call didn't work
|
||||
e = job.future._exception
|
||||
if e is not None:
|
||||
stre = str(e)
|
||||
strex = "".join(
|
||||
traceback.format_tb(e.__traceback__)
|
||||
)
|
||||
else:
|
||||
stre = ""
|
||||
strex = ""
|
||||
|
||||
print(
|
||||
"Bad final response: %s %s %s %s %s: %s %s"
|
||||
% (
|
||||
base_model,
|
||||
inference_server,
|
||||
res_all,
|
||||
prompt,
|
||||
text,
|
||||
stre,
|
||||
strex,
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
if gr_prompt_type == "plain":
|
||||
# then gradio server passes back full prompt + text
|
||||
prompt_and_text = text
|
||||
else:
|
||||
prompt_and_text = prompt + text
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
prompt_and_text,
|
||||
prompt=prompt,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources=sources,
|
||||
)
|
||||
elif hf_client:
|
||||
# HF inference server needs control over input tokens
|
||||
where_from = "hf_client"
|
||||
|
||||
# prompt must include all human-bot like tokens, already added by prompt
|
||||
# https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types
|
||||
stop_sequences = list(
|
||||
set(
|
||||
prompter.terminate_response
|
||||
+ [prompter.PreResponse]
|
||||
)
|
||||
)
|
||||
stop_sequences = [x for x in stop_sequences if x]
|
||||
gen_server_kwargs = dict(
|
||||
do_sample=do_sample,
|
||||
max_new_tokens=max_new_tokens,
|
||||
# best_of=None,
|
||||
repetition_penalty=repetition_penalty,
|
||||
return_full_text=True,
|
||||
seed=SEED,
|
||||
stop_sequences=stop_sequences,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
# truncate=False, # behaves oddly
|
||||
# typical_p=top_p,
|
||||
# watermark=False,
|
||||
# decoder_input_details=False,
|
||||
)
|
||||
# work-around for timeout at constructor time, will be issue if multi-threading,
|
||||
# so just do something reasonable or max_time if larger
|
||||
# lower bound because client is re-used if multi-threading
|
||||
hf_client.timeout = max(300, max_time)
|
||||
if not stream_output:
|
||||
text = hf_client.generate(
|
||||
prompt, **gen_server_kwargs
|
||||
).generated_text
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
text,
|
||||
prompt=prompt,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources="",
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
for response in hf_client.generate_stream(
|
||||
prompt, **gen_server_kwargs
|
||||
):
|
||||
if not response.token.special:
|
||||
# stop_sequences
|
||||
text_chunk = response.token.text
|
||||
text += text_chunk
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
prompt + text,
|
||||
prompt=prompt,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources="",
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Failed to get client: %s" % inference_server
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No such inference_server %s" % inference_server
|
||||
)
|
||||
|
||||
if save_dir and text:
|
||||
# save prompt + new text
|
||||
extra_dict = gen_server_kwargs.copy()
|
||||
extra_dict.update(
|
||||
dict(
|
||||
inference_server=inference_server,
|
||||
num_prompt_tokens=num_prompt_tokens,
|
||||
)
|
||||
)
|
||||
save_generate_output(
|
||||
prompt=prompt,
|
||||
output=text,
|
||||
base_model=base_model,
|
||||
save_dir=save_dir,
|
||||
where_from=where_from,
|
||||
extra_dict=extra_dict,
|
||||
)
|
||||
return
|
||||
else:
|
||||
assert not inference_server, (
|
||||
"inferene_server=%s not supported" % inference_server
|
||||
)
|
||||
return out
|
||||
|
||||
if isinstance(tokenizer, str):
|
||||
# pipeline
|
||||
if tokenizer == "summarization":
|
||||
key = "summary_text"
|
||||
else:
|
||||
raise RuntimeError("No such task type %s" % tokenizer)
|
||||
# NOTE: uses max_length only
|
||||
yield dict(
|
||||
response=model(prompt, max_length=max_new_tokens)[0][key],
|
||||
sources="",
|
||||
)
|
||||
|
||||
if "mbart-" in base_model.lower():
|
||||
assert src_lang is not None
|
||||
tokenizer.src_lang = self.languages_covered()[src_lang]
|
||||
|
||||
stopping_criteria = get_stopping(
|
||||
prompt_type,
|
||||
prompt_dict,
|
||||
tokenizer,
|
||||
self.device,
|
||||
model_max_length=tokenizer.model_max_length,
|
||||
)
|
||||
|
||||
print(prompt)
|
||||
# exit(0)
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
if debug and len(inputs["input_ids"]) > 0:
|
||||
print("input_ids length", len(inputs["input_ids"][0]), flush=True)
|
||||
input_ids = inputs["input_ids"].to(self.device)
|
||||
# CRITICAL LIMIT else will fail
|
||||
max_max_tokens = tokenizer.model_max_length
|
||||
max_input_tokens = max_max_tokens - min_new_tokens
|
||||
# NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py
|
||||
input_ids = input_ids[:, -max_input_tokens:]
|
||||
# required for falcon if multiple threads or asyncio accesses to model during generation
|
||||
if use_cache is None:
|
||||
use_cache = False if "falcon" in base_model else True
|
||||
gen_config_kwargs = dict(
|
||||
temperature=float(temperature),
|
||||
top_p=float(top_p),
|
||||
top_k=top_k,
|
||||
num_beams=num_beams,
|
||||
do_sample=do_sample,
|
||||
repetition_penalty=float(repetition_penalty),
|
||||
num_return_sequences=num_return_sequences,
|
||||
renormalize_logits=True,
|
||||
remove_invalid_values=True,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
token_ids = [
|
||||
"eos_token_id",
|
||||
"pad_token_id",
|
||||
"bos_token_id",
|
||||
"cls_token_id",
|
||||
"sep_token_id",
|
||||
]
|
||||
for token_id in token_ids:
|
||||
if (
|
||||
hasattr(tokenizer, token_id)
|
||||
and getattr(tokenizer, token_id) is not None
|
||||
):
|
||||
gen_config_kwargs.update(
|
||||
{token_id: getattr(tokenizer, token_id)}
|
||||
)
|
||||
generation_config = GenerationConfig(**gen_config_kwargs)
|
||||
|
||||
gen_kwargs = dict(
|
||||
input_ids=input_ids,
|
||||
generation_config=generation_config,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
max_new_tokens=max_new_tokens, # prompt + new
|
||||
min_new_tokens=min_new_tokens, # prompt + new
|
||||
early_stopping=early_stopping, # False, True, "never"
|
||||
max_time=max_time,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
if "gpt2" in base_model.lower():
|
||||
gen_kwargs.update(
|
||||
dict(
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
)
|
||||
elif "mbart-" in base_model.lower():
|
||||
assert tgt_lang is not None
|
||||
tgt_lang = self.languages_covered()[tgt_lang]
|
||||
gen_kwargs.update(
|
||||
dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])
|
||||
)
|
||||
else:
|
||||
token_ids = ["eos_token_id", "bos_token_id", "pad_token_id"]
|
||||
for token_id in token_ids:
|
||||
if (
|
||||
hasattr(tokenizer, token_id)
|
||||
and getattr(tokenizer, token_id) is not None
|
||||
):
|
||||
gen_kwargs.update({token_id: getattr(tokenizer, token_id)})
|
||||
|
||||
decoder_kwargs = dict(
|
||||
skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)
|
||||
|
||||
decoder = functools.partial(tokenizer.decode, **decoder_kwargs)
|
||||
decoder_raw_kwargs = dict(
|
||||
skip_special_tokens=False, clean_up_tokenization_spaces=True
|
||||
)
|
||||
|
||||
decoder_raw = functools.partial(tokenizer.decode, **decoder_raw_kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
have_lora_weights = lora_weights not in [no_lora_str, "", None]
|
||||
context_class_cast = (
|
||||
NullContext
|
||||
if self.device == "cpu"
|
||||
or have_lora_weights
|
||||
or self.device == "mps"
|
||||
else torch.autocast
|
||||
)
|
||||
with context_class_cast(self.device):
|
||||
# protection for gradio not keeping track of closed users,
|
||||
# else hit bitsandbytes lack of thread safety:
|
||||
# https://github.com/h2oai/h2ogpt/issues/104
|
||||
# but only makes sense if concurrency_count == 1
|
||||
context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
|
||||
if verbose:
|
||||
print("Pre-Generate: %s" % str(datetime.now()), flush=True)
|
||||
decoded_output = None
|
||||
with context_class("generate.lock"):
|
||||
if verbose:
|
||||
print("Generate: %s" % str(datetime.now()), flush=True)
|
||||
# decoded tokenized prompt can deviate from prompt due to special characters
|
||||
inputs_decoded = decoder(input_ids[0])
|
||||
inputs_decoded_raw = decoder_raw(input_ids[0])
|
||||
if inputs_decoded == prompt:
|
||||
# normal
|
||||
pass
|
||||
elif inputs_decoded.lstrip() == prompt.lstrip():
|
||||
# sometimes extra space in front, make prompt same for prompt removal
|
||||
prompt = inputs_decoded
|
||||
elif inputs_decoded_raw == prompt:
|
||||
# some models specify special tokens that are part of normal prompt, so can't skip them
|
||||
inputs_decoded = prompt = inputs_decoded_raw
|
||||
decoder = decoder_raw
|
||||
decoder_kwargs = decoder_raw_kwargs
|
||||
elif inputs_decoded_raw.replace("<unk> ", "").replace(
|
||||
"<unk>", ""
|
||||
).replace("\n", " ").replace(" ", "") == prompt.replace(
|
||||
"\n", " "
|
||||
).replace(
|
||||
" ", ""
|
||||
):
|
||||
inputs_decoded = prompt = inputs_decoded_raw
|
||||
decoder = decoder_raw
|
||||
decoder_kwargs = decoder_raw_kwargs
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"WARNING: Special characters in prompt",
|
||||
flush=True,
|
||||
)
|
||||
if stream_output:
|
||||
skip_prompt = False
|
||||
streamer = H2OTextIteratorStreamer(
|
||||
tokenizer,
|
||||
skip_prompt=skip_prompt,
|
||||
block=False,
|
||||
**decoder_kwargs,
|
||||
)
|
||||
gen_kwargs.update(dict(streamer=streamer))
|
||||
target = wrapped_partial(
|
||||
self.generate_with_exceptions,
|
||||
model.generate,
|
||||
prompt=prompt,
|
||||
inputs_decoded=inputs_decoded,
|
||||
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
|
||||
**gen_kwargs,
|
||||
)
|
||||
bucket = queue.Queue()
|
||||
thread = EThread(
|
||||
target=target, streamer=streamer, bucket=bucket
|
||||
)
|
||||
thread.start()
|
||||
outputs = ""
|
||||
try:
|
||||
for new_text in streamer:
|
||||
if bucket.qsize() > 0 or thread.exc:
|
||||
thread.join()
|
||||
outputs += new_text
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
outputs,
|
||||
prompt=inputs_decoded,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources="",
|
||||
)
|
||||
except BaseException:
|
||||
# if any exception, raise that exception if was from thread, first
|
||||
if thread.exc:
|
||||
raise thread.exc
|
||||
raise
|
||||
finally:
|
||||
# clear before return, since .then() never done if from API
|
||||
clear_torch_cache()
|
||||
# in case no exception and didn't join with thread yet, then join
|
||||
if not thread.exc:
|
||||
thread.join()
|
||||
# in case raise StopIteration or broke queue loop in streamer, but still have exception
|
||||
if thread.exc:
|
||||
raise thread.exc
|
||||
decoded_output = outputs
|
||||
else:
|
||||
try:
|
||||
outputs = model.generate(**gen_kwargs)
|
||||
finally:
|
||||
clear_torch_cache() # has to be here for API submit_nochat_api since.then() not called
|
||||
outputs = [decoder(s) for s in outputs.sequences]
|
||||
yield dict(
|
||||
response=prompter.get_response(
|
||||
outputs,
|
||||
prompt=inputs_decoded,
|
||||
sanitize_bot_response=sanitize_bot_response,
|
||||
),
|
||||
sources="",
|
||||
)
|
||||
if outputs and len(outputs) >= 1:
|
||||
decoded_output = prompt + outputs[0]
|
||||
if save_dir and decoded_output:
|
||||
extra_dict = gen_config_kwargs.copy()
|
||||
extra_dict.update(
|
||||
dict(num_prompt_tokens=num_prompt_tokens)
|
||||
)
|
||||
save_generate_output(
|
||||
prompt=prompt,
|
||||
output=decoded_output,
|
||||
base_model=base_model,
|
||||
save_dir=save_dir,
|
||||
where_from="evaluate_%s" % str(stream_output),
|
||||
extra_dict=gen_config_kwargs,
|
||||
)
|
||||
if verbose:
|
||||
print(
|
||||
"Post-Generate: %s decoded_output: %s"
|
||||
% (
|
||||
str(datetime.now()),
|
||||
len(decoded_output) if decoded_output else -1,
|
||||
),
|
||||
flush=True,
|
||||
)
|
||||
return outputs[0]
|
||||
|
||||
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
||||
global inputs_kwargs_list
|
||||
|
||||
@@ -436,7 +436,7 @@ class GradioInference(LLM):
|
||||
chat_client: bool = False
|
||||
|
||||
return_full_text: bool = True
|
||||
stream_output: bool = Field(False, alias="stream")
|
||||
stream: bool = False
|
||||
sanitize_bot_response: bool = False
|
||||
|
||||
prompter: Any = None
|
||||
@@ -481,7 +481,7 @@ class GradioInference(LLM):
|
||||
# so server should get prompt_type or '', not plain
|
||||
# This is good, so gradio server can also handle stopping.py conditions
|
||||
# this is different than TGI server that uses prompter to inject prompt_type prompting
|
||||
stream_output = self.stream_output
|
||||
stream_output = self.stream
|
||||
gr_client = self.client
|
||||
client_langchain_mode = "Disabled"
|
||||
client_langchain_action = LangChainAction.QUERY.value
|
||||
@@ -596,7 +596,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
||||
inference_server_url: str = ""
|
||||
timeout: int = 300
|
||||
headers: dict = None
|
||||
stream_output: bool = Field(False, alias="stream")
|
||||
stream: bool = False
|
||||
sanitize_bot_response: bool = False
|
||||
prompter: Any = None
|
||||
tokenizer: Any = None
|
||||
@@ -663,7 +663,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
||||
# lower bound because client is re-used if multi-threading
|
||||
self.client.timeout = max(300, self.timeout)
|
||||
|
||||
if not self.stream_output:
|
||||
if not self.stream:
|
||||
res = self.client.generate(
|
||||
prompt,
|
||||
**gen_server_kwargs,
|
||||
@@ -852,7 +852,7 @@ def get_llm(
|
||||
top_p=top_p,
|
||||
# typical_p=top_p,
|
||||
callbacks=callbacks if stream_output else None,
|
||||
stream_output=stream_output,
|
||||
stream=stream_output,
|
||||
prompter=prompter,
|
||||
tokenizer=tokenizer,
|
||||
client=hf_client,
|
||||
@@ -2510,7 +2510,8 @@ def _run_qa_db(
|
||||
formatted_doc_chunks = "\n\n".join(
|
||||
[get_url(x) + "\n\n" + x.page_content for x in docs]
|
||||
)
|
||||
return formatted_doc_chunks, ""
|
||||
yield formatted_doc_chunks, ""
|
||||
return
|
||||
if not docs and langchain_action in [
|
||||
LangChainAction.SUMMARIZE_MAP.value,
|
||||
LangChainAction.SUMMARIZE_ALL.value,
|
||||
@@ -2522,7 +2523,8 @@ def _run_qa_db(
|
||||
else "No documents to summarize."
|
||||
)
|
||||
extra = ""
|
||||
return ret, extra
|
||||
yield ret, extra
|
||||
return
|
||||
if not docs and langchain_mode not in [
|
||||
LangChainMode.DISABLED.value,
|
||||
LangChainMode.CHAT_LLM.value,
|
||||
@@ -2534,7 +2536,8 @@ def _run_qa_db(
|
||||
else "No documents to query."
|
||||
)
|
||||
extra = ""
|
||||
return ret, extra
|
||||
yield ret, extra
|
||||
return
|
||||
|
||||
if chain is None and model_name not in non_hf_types:
|
||||
# here if no docs at all and not HF type
|
||||
@@ -2554,7 +2557,22 @@ def _run_qa_db(
|
||||
)
|
||||
with context_class_cast(args.device):
|
||||
answer = chain()
|
||||
return answer
|
||||
|
||||
if not use_context:
|
||||
ret = answer["output_text"]
|
||||
extra = ""
|
||||
yield ret, extra
|
||||
elif answer is not None:
|
||||
ret, extra = get_sources_answer(
|
||||
query,
|
||||
answer,
|
||||
scores,
|
||||
show_rank,
|
||||
answer_with_sources,
|
||||
verbose=verbose,
|
||||
)
|
||||
yield ret, extra
|
||||
return
|
||||
|
||||
|
||||
def get_similarity_chain(
|
||||
|
||||
@@ -3,11 +3,13 @@ from apps.stable_diffusion.src.utils.utils import _compile_module
|
||||
from io import BytesIO
|
||||
import torch_mlir
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
|
||||
from stopping import get_stopping
|
||||
from prompter import Prompter, PromptType
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
|
||||
from transformers.generation import (
|
||||
GenerationConfig,
|
||||
LogitsProcessorList,
|
||||
@@ -29,8 +31,14 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(
|
||||
lhs: List[int],
|
||||
rhs: List[int],
|
||||
rhs_scale: List[int],
|
||||
rhs_zero_point: List[int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
@@ -39,21 +47,30 @@ def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_s
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(
|
||||
lhs_rank_dtype: Tuple[int, int],
|
||||
rhs_rank_dtype: Tuple[int, int],
|
||||
rhs_scale_rank_dtype: Tuple[int, int],
|
||||
rhs_zero_point_rank_dtype: Tuple[int, int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
|
||||
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
quant〇matmul_rhs_group_quant〡shape,
|
||||
quant〇matmul_rhs_group_quant〡dtype,
|
||||
quant〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
|
||||
]
|
||||
|
||||
global_device = "cuda"
|
||||
global_precision = "fp16"
|
||||
@@ -229,7 +246,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
@@ -268,215 +285,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
return result
|
||||
|
||||
|
||||
def decode_tokens(tokenizer, res_tokens):
|
||||
for i in range(len(res_tokens)):
|
||||
if type(res_tokens[i]) != int:
|
||||
res_tokens[i] = int(res_tokens[i][0])
|
||||
|
||||
res_str = tokenizer.decode(res_tokens, skip_special_tokens=True)
|
||||
return res_str
|
||||
|
||||
|
||||
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
|
||||
del generate_kwargs["max_time"]
|
||||
generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
|
||||
device=tensor_device
|
||||
)
|
||||
generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
|
||||
device=tensor_device
|
||||
)
|
||||
truncated_input_ids = []
|
||||
stopping_criteria = generate_kwargs["stopping_criteria"]
|
||||
|
||||
generation_config_ = GenerationConfig.from_model_config(model.config)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
stopping_criteria = (
|
||||
stopping_criteria
|
||||
if stopping_criteria is not None
|
||||
else StoppingCriteriaList()
|
||||
)
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
model_kwargs,
|
||||
) = model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
input_ids_seq_length = input_ids.shape[-1]
|
||||
|
||||
generation_config.max_length = (
|
||||
generation_config.max_new_tokens + input_ids_seq_length
|
||||
)
|
||||
|
||||
logits_processor = model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
stopping_criteria = model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=stopping_criteria,
|
||||
)
|
||||
|
||||
logits_warper = model._get_logits_warper(generation_config)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
model_kwargs,
|
||||
) = model._expand_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=model.config.is_encoder_decoder, # False
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id).to(device=tensor_device)
|
||||
if eos_token_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
pad_token_id = generation_config.pad_token_id
|
||||
eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and output_scores) else None
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
unfinished_sequences = torch.ones(
|
||||
input_ids.shape[0],
|
||||
dtype=torch.long,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
timesRan = 0
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
print("\n")
|
||||
|
||||
res_tokens = []
|
||||
while True:
|
||||
model_inputs = model.prepare_inputs_for_generation(
|
||||
input_ids, **model_kwargs
|
||||
)
|
||||
|
||||
outputs = h2ogpt_shark_model.forward(
|
||||
model_inputs["input_ids"], model_inputs["attention_mask"]
|
||||
)
|
||||
|
||||
if args.precision == "fp16":
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = logits_processor(input_ids, next_token_logits)
|
||||
next_token_scores = logits_warper(input_ids, next_token_scores)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
if pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = next_token * unfinished_sequences + pad_token_id * (
|
||||
1 - unfinished_sequences
|
||||
)
|
||||
|
||||
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
|
||||
|
||||
model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
truncated_input_ids.append(input_ids[:, 0])
|
||||
input_ids = input_ids[:, 1:]
|
||||
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:]
|
||||
|
||||
new_word = tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
res_tokens.append(next_token)
|
||||
if new_word == "<0x0A>":
|
||||
print("\n", end="", flush=True)
|
||||
else:
|
||||
print(f"{new_word}", end=" ", flush=True)
|
||||
|
||||
part_str = decode_tokens(tokenizer, res_tokens)
|
||||
yield part_str
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id_tensor is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul(
|
||||
next_token.tile(eos_token_id_tensor.shape[0], 1)
|
||||
.ne(eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if unfinished_sequences.max() == 0 or stopping_criteria(
|
||||
input_ids, scores
|
||||
):
|
||||
break
|
||||
timesRan = timesRan + 1
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / timesRan
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
res_str = decode_tokens(tokenizer, res_tokens)
|
||||
yield res_str
|
||||
h2ogpt_model = H2OGPTSHARKModel()
|
||||
|
||||
|
||||
def pad_or_truncate_inputs(
|
||||
@@ -689,6 +498,233 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
||||
)
|
||||
return records
|
||||
|
||||
def generate_new_token(self):
|
||||
model_inputs = self.model.prepare_inputs_for_generation(
|
||||
self.input_ids, **self.model_kwargs
|
||||
)
|
||||
|
||||
outputs = h2ogpt_model.forward(
|
||||
model_inputs["input_ids"], model_inputs["attention_mask"]
|
||||
)
|
||||
|
||||
if args.precision == "fp16":
|
||||
outputs = outputs.to(dtype=torch.float32)
|
||||
next_token_logits = outputs
|
||||
|
||||
# pre-process distribution
|
||||
next_token_scores = self.logits_processor(
|
||||
self.input_ids, next_token_logits
|
||||
)
|
||||
next_token_scores = self.logits_warper(
|
||||
self.input_ids, next_token_scores
|
||||
)
|
||||
|
||||
# sample
|
||||
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
|
||||
|
||||
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if self.eos_token_id is not None:
|
||||
if self.pad_token_id is None:
|
||||
raise ValueError(
|
||||
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
|
||||
)
|
||||
next_token = (
|
||||
next_token * self.unfinished_sequences
|
||||
+ self.pad_token_id * (1 - self.unfinished_sequences)
|
||||
)
|
||||
|
||||
self.input_ids = torch.cat(
|
||||
[self.input_ids, next_token[:, None]], dim=-1
|
||||
)
|
||||
|
||||
self.model_kwargs["past_key_values"] = None
|
||||
if "attention_mask" in self.model_kwargs:
|
||||
attention_mask = self.model_kwargs["attention_mask"]
|
||||
self.model_kwargs["attention_mask"] = torch.cat(
|
||||
[
|
||||
attention_mask,
|
||||
attention_mask.new_ones((attention_mask.shape[0], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
self.truncated_input_ids.append(self.input_ids[:, 0])
|
||||
self.input_ids = self.input_ids[:, 1:]
|
||||
self.model_kwargs["attention_mask"] = self.model_kwargs[
|
||||
"attention_mask"
|
||||
][:, 1:]
|
||||
|
||||
return next_token
|
||||
|
||||
def generate_token(self, **generate_kwargs):
|
||||
del generate_kwargs["max_time"]
|
||||
self.truncated_input_ids = []
|
||||
|
||||
generation_config_ = GenerationConfig.from_model_config(
|
||||
self.model.config
|
||||
)
|
||||
generation_config = copy.deepcopy(generation_config_)
|
||||
self.model_kwargs = generation_config.update(**generate_kwargs)
|
||||
|
||||
logits_processor = LogitsProcessorList()
|
||||
self.stopping_criteria = (
|
||||
self.stopping_criteria
|
||||
if self.stopping_criteria is not None
|
||||
else StoppingCriteriaList()
|
||||
)
|
||||
|
||||
eos_token_id = generation_config.eos_token_id
|
||||
generation_config.pad_token_id = eos_token_id
|
||||
|
||||
(
|
||||
inputs_tensor,
|
||||
model_input_name,
|
||||
self.model_kwargs,
|
||||
) = self.model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, self.model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
|
||||
self.model_kwargs[
|
||||
"output_attentions"
|
||||
] = generation_config.output_attentions
|
||||
self.model_kwargs[
|
||||
"output_hidden_states"
|
||||
] = generation_config.output_hidden_states
|
||||
self.model_kwargs["use_cache"] = generation_config.use_cache
|
||||
|
||||
self.input_ids = (
|
||||
inputs_tensor
|
||||
if model_input_name == "input_ids"
|
||||
else self.model_kwargs.pop("input_ids")
|
||||
)
|
||||
|
||||
input_ids_seq_length = self.input_ids.shape[-1]
|
||||
|
||||
generation_config.max_length = (
|
||||
generation_config.max_new_tokens + input_ids_seq_length
|
||||
)
|
||||
|
||||
self.logits_processor = self.model._get_logits_processor(
|
||||
generation_config=generation_config,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
encoder_input_ids=inputs_tensor,
|
||||
prefix_allowed_tokens_fn=None,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
|
||||
self.stopping_criteria = self.model._get_stopping_criteria(
|
||||
generation_config=generation_config,
|
||||
stopping_criteria=self.stopping_criteria,
|
||||
)
|
||||
|
||||
self.logits_warper = self.model._get_logits_warper(generation_config)
|
||||
|
||||
(
|
||||
self.input_ids,
|
||||
self.model_kwargs,
|
||||
) = self.model._expand_inputs_for_generation(
|
||||
input_ids=self.input_ids,
|
||||
expand_size=generation_config.num_return_sequences, # 1
|
||||
is_encoder_decoder=self.model.config.is_encoder_decoder, # False
|
||||
**self.model_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(eos_token_id, int):
|
||||
eos_token_id = [eos_token_id]
|
||||
self.eos_token_id_tensor = (
|
||||
torch.tensor(eos_token_id).to(device=tensor_device)
|
||||
if eos_token_id is not None
|
||||
else None
|
||||
)
|
||||
|
||||
self.pad_token_id = generation_config.pad_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
output_attentions = generation_config.output_attentions # False
|
||||
output_hidden_states = generation_config.output_hidden_states # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
self.scores = (
|
||||
() if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
decoder_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
cross_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
decoder_hidden_states = (
|
||||
() if (return_dict_in_generate and output_hidden_states) else None
|
||||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
self.unfinished_sequences = torch.ones(
|
||||
self.input_ids.shape[0],
|
||||
dtype=torch.long,
|
||||
device=self.input_ids.device,
|
||||
)
|
||||
|
||||
timesRan = 0
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
print("\n")
|
||||
|
||||
while True:
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
print(f"{new_word}", end="", flush=True)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if self.eos_token_id_tensor is not None:
|
||||
self.unfinished_sequences = self.unfinished_sequences.mul(
|
||||
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
|
||||
.ne(self.eos_token_id_tensor.unsqueeze(1))
|
||||
.prod(dim=0)
|
||||
)
|
||||
# stop when each sentence is finished
|
||||
if (
|
||||
self.unfinished_sequences.max() == 0
|
||||
or self.stopping_criteria(self.input_ids, self.scores)
|
||||
):
|
||||
break
|
||||
timesRan = timesRan + 1
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / timesRan
|
||||
)
|
||||
)
|
||||
|
||||
self.input_ids = torch.cat(
|
||||
[
|
||||
torch.tensor(self.truncated_input_ids)
|
||||
.to(device=tensor_device)
|
||||
.unsqueeze(dim=0),
|
||||
self.input_ids,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return self.input_ids
|
||||
|
||||
def _forward(self, model_inputs, **generate_kwargs):
|
||||
if self.can_stop:
|
||||
stopping_criteria = get_stopping(
|
||||
@@ -748,13 +784,19 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
||||
input_ids, attention_mask = pad_or_truncate_inputs(
|
||||
input_ids, attention_mask, max_padding_length=max_padding_length
|
||||
)
|
||||
self.stopping_criteria = generate_kwargs["stopping_criteria"]
|
||||
|
||||
return_dict = {
|
||||
"model": self.model,
|
||||
"tokenizer": self.tokenizer,
|
||||
generated_sequence = self.generate_token(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
**generate_kwargs,
|
||||
)
|
||||
out_b = generated_sequence.shape[0]
|
||||
generated_sequence = generated_sequence.reshape(
|
||||
in_b, out_b // in_b, *generated_sequence.shape[1:]
|
||||
)
|
||||
return {
|
||||
"generated_sequence": generated_sequence,
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"attention_mask": attention_mask,
|
||||
"prompt_text": prompt_text,
|
||||
}
|
||||
return_dict = {**return_dict, **generate_kwargs}
|
||||
return return_dict
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import fire
|
||||
|
||||
from gpt_langchain import (
|
||||
path_to_docs,
|
||||
@@ -201,3 +202,7 @@ def make_db_main(
|
||||
if verbose:
|
||||
print("DONE", flush=True)
|
||||
return db, collection_name
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(make_db_main)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import gc
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import List, Tuple
|
||||
import subprocess
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
@@ -27,14 +25,6 @@ from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
VicunaNorm,
|
||||
VicunaNormCompiled,
|
||||
)
|
||||
from apps.language_models.src.model_wrappers.vicuna4 import (
|
||||
LlamaModel,
|
||||
EightLayerLayerSV,
|
||||
EightLayerLayerFV,
|
||||
CompiledEightLayerLayerSV,
|
||||
CompiledEightLayerLayer,
|
||||
forward_compressed,
|
||||
)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna,
|
||||
@@ -50,13 +40,16 @@ from shark.shark_inference import SharkInference
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
if __name__ == "__main__":
|
||||
import gc
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="vicuna runner",
|
||||
description="runs a vicuna model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="int8", help="fp32, fp16, int8, int4"
|
||||
"--precision", "-p", default="fp32", help="fp32, fp16, int8, int4"
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
@@ -121,17 +114,11 @@ parser.add_argument(
|
||||
"--cache_vicunas",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="For debugging purposes, creates a first_{precision}.mlir and second_{precision}.mlir and stores on disk",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--iree_vulkan_target_triple",
|
||||
type=str,
|
||||
default="",
|
||||
help="Specify target triple for vulkan.",
|
||||
help="For debugging purposes, creates a first_{precision}.mlir and second_{precision}.mlir and stores on disk"
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
@@ -140,21 +127,20 @@ def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_s
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
quant〇matmul_rhs_group_quant〡shape,
|
||||
quant〇matmul_rhs_group_quant〡dtype,
|
||||
quant〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
|
||||
|
||||
class VicunaBase(SharkLLMBase):
|
||||
@@ -165,13 +151,11 @@ class VicunaBase(SharkLLMBase):
|
||||
max_num_tokens=512,
|
||||
device="cpu",
|
||||
precision="int8",
|
||||
extra_args_cmd=[],
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.extra_args = extra_args_cmd
|
||||
|
||||
def get_tokenizer(self):
|
||||
# Retrieve the tokenizer from Huggingface
|
||||
@@ -192,14 +176,11 @@ class VicunaBase(SharkLLMBase):
|
||||
self, first_vicuna_mlir, second_vicuna_mlir, output_name
|
||||
):
|
||||
print(f"[DEBUG] combining first and second mlir")
|
||||
print(f"[DEBIG] output_name = {output_name}")
|
||||
maps1 = []
|
||||
maps2 = []
|
||||
constants = set()
|
||||
f1 = []
|
||||
f2 = []
|
||||
|
||||
print(f"[DEBUG] processing first vircuna mlir")
|
||||
first_vicuna_mlir = first_vicuna_mlir.splitlines()
|
||||
while first_vicuna_mlir:
|
||||
line = first_vicuna_mlir.pop(0)
|
||||
@@ -212,7 +193,6 @@ class VicunaBase(SharkLLMBase):
|
||||
f1.append(line)
|
||||
f1 = f1[:-1]
|
||||
del first_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps1):
|
||||
map_var = map_line.split(" ")[0]
|
||||
@@ -223,7 +203,6 @@ class VicunaBase(SharkLLMBase):
|
||||
for func_line in f1
|
||||
]
|
||||
|
||||
print(f"[DEBUG] processing second vircuna mlir")
|
||||
second_vicuna_mlir = second_vicuna_mlir.splitlines()
|
||||
while second_vicuna_mlir:
|
||||
line = second_vicuna_mlir.pop(0)
|
||||
@@ -237,8 +216,6 @@ class VicunaBase(SharkLLMBase):
|
||||
line = re.sub("forward", "second_vicuna_forward", line)
|
||||
f2.append(line)
|
||||
f2 = f2[:-1]
|
||||
del second_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps2):
|
||||
map_var = map_line.split(" ")[0]
|
||||
@@ -259,7 +236,6 @@ class VicunaBase(SharkLLMBase):
|
||||
global_var_loading1 = []
|
||||
global_var_loading2 = []
|
||||
|
||||
print(f"[DEBUG] processing constants")
|
||||
counter = 0
|
||||
constants = list(constants)
|
||||
while constants:
|
||||
@@ -283,7 +259,7 @@ class VicunaBase(SharkLLMBase):
|
||||
vnames.append(vname)
|
||||
if "true" not in vname:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
|
||||
f"ml_program.global public @{vname}({vbody}) : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
@@ -293,7 +269,7 @@ class VicunaBase(SharkLLMBase):
|
||||
)
|
||||
else:
|
||||
global_vars.append(
|
||||
f"ml_program.global private @{vname}({vbody}) : i1"
|
||||
f"ml_program.global public @{vname}({vbody}) : i1"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
@@ -303,7 +279,6 @@ class VicunaBase(SharkLLMBase):
|
||||
)
|
||||
new_f1, new_f2 = [], []
|
||||
|
||||
print(f"[DEBUG] processing f1")
|
||||
for line in f1:
|
||||
if "func.func" in line:
|
||||
new_f1.append(line)
|
||||
@@ -312,7 +287,6 @@ class VicunaBase(SharkLLMBase):
|
||||
else:
|
||||
new_f1.append(line)
|
||||
|
||||
print(f"[DEBUG] processing f2")
|
||||
for line in f2:
|
||||
if "func.func" in line:
|
||||
new_f2.append(line)
|
||||
@@ -331,45 +305,29 @@ class VicunaBase(SharkLLMBase):
|
||||
|
||||
f1 = new_f1
|
||||
f2 = new_f2
|
||||
|
||||
del new_f1
|
||||
del new_f2
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
[
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
|
||||
for x in [maps1, maps2, global_vars, f1, f2]
|
||||
]
|
||||
)
|
||||
whole_string = "\n".join(
|
||||
maps1
|
||||
+ maps2
|
||||
+ [module_start]
|
||||
+ global_vars
|
||||
+ f1
|
||||
+ f2
|
||||
+ [module_end]
|
||||
)
|
||||
|
||||
# doing it this way rather than assembling the whole string
|
||||
# to prevent OOM with 64GiB RAM when encoding the file.
|
||||
f_ = open(output_name, "w+")
|
||||
f_.write(whole_string)
|
||||
f_.close()
|
||||
|
||||
print(f"[DEBUG] Saving mlir to {output_name}")
|
||||
with open(output_name, "w+") as f_:
|
||||
f_.writelines(line + "\n" for line in maps1)
|
||||
f_.writelines(line + "\n" for line in maps2)
|
||||
f_.writelines(line + "\n" for line in [module_start])
|
||||
f_.writelines(line + "\n" for line in global_vars)
|
||||
f_.writelines(line + "\n" for line in f1)
|
||||
f_.writelines(line + "\n" for line in f2)
|
||||
f_.writelines(line + "\n" for line in [module_end])
|
||||
return whole_string
|
||||
|
||||
del maps1
|
||||
del maps2
|
||||
del module_start
|
||||
del global_vars
|
||||
del f1
|
||||
del f2
|
||||
del module_end
|
||||
gc.collect()
|
||||
|
||||
print(f"[DEBUG] Reading combined mlir back in")
|
||||
with open(output_name, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
def generate_new_token(self, params, sharded=True, cli=True):
|
||||
def generate_new_token(self, params, sharded=True):
|
||||
is_first = params["is_first"]
|
||||
if is_first:
|
||||
prompt = params["prompt"]
|
||||
@@ -408,6 +366,7 @@ class VicunaBase(SharkLLMBase):
|
||||
_past_key_values = output["past_key_values"]
|
||||
_token = int(torch.argmax(_logits[:, -1, :], dim=1)[0])
|
||||
else:
|
||||
print(len(output))
|
||||
_logits = torch.tensor(output[0])
|
||||
_past_key_values = torch.tensor(output[1:])
|
||||
_token = torch.argmax(_logits[:, -1, :], dim=1)
|
||||
@@ -421,8 +380,7 @@ class VicunaBase(SharkLLMBase):
|
||||
"past_key_values": _past_key_values,
|
||||
}
|
||||
|
||||
if cli:
|
||||
print(f" token : {_token} | detok : {_detok}")
|
||||
print(f" token : {_token} | detok : {_detok}")
|
||||
|
||||
return ret_dict
|
||||
|
||||
@@ -438,17 +396,14 @@ class ShardedVicuna(VicunaBase):
|
||||
precision="fp32",
|
||||
config_json=None,
|
||||
weight_group_size=128,
|
||||
compressed=False,
|
||||
extra_args_cmd=[],
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.config = config_json
|
||||
self.weight_group_size = weight_group_size
|
||||
self.compressed = compressed
|
||||
self.shark_model = self.compile(device=device)
|
||||
|
||||
def get_tokenizer(self):
|
||||
@@ -561,59 +516,6 @@ class ShardedVicuna(VicunaBase):
|
||||
)
|
||||
return mlir_bytecode
|
||||
|
||||
def compile_vicuna_layer4(
|
||||
self,
|
||||
vicuna_layer,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values=None,
|
||||
):
|
||||
# Compile a hidden decoder layer of vicuna
|
||||
if past_key_values is None:
|
||||
model_inputs = (hidden_states, attention_mask, position_ids)
|
||||
else:
|
||||
(
|
||||
(pkv00, pkv01),
|
||||
(pkv10, pkv11),
|
||||
(pkv20, pkv21),
|
||||
(pkv30, pkv31),
|
||||
(pkv40, pkv41),
|
||||
(pkv50, pkv51),
|
||||
(pkv60, pkv61),
|
||||
(pkv70, pkv71),
|
||||
) = past_key_values
|
||||
|
||||
model_inputs = (
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv00,
|
||||
pkv01,
|
||||
pkv10,
|
||||
pkv11,
|
||||
pkv20,
|
||||
pkv21,
|
||||
pkv30,
|
||||
pkv31,
|
||||
pkv40,
|
||||
pkv41,
|
||||
pkv50,
|
||||
pkv51,
|
||||
pkv60,
|
||||
pkv61,
|
||||
pkv70,
|
||||
pkv71,
|
||||
)
|
||||
mlir_bytecode = import_with_fx(
|
||||
vicuna_layer,
|
||||
model_inputs,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
return mlir_bytecode
|
||||
|
||||
def get_device_index(self, layer_string):
|
||||
# Get the device index from the config file
|
||||
# In the event that different device indices are assigned to
|
||||
@@ -647,27 +549,18 @@ class ShardedVicuna(VicunaBase):
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
# module = torch_mlir.compile(
|
||||
# lmh,
|
||||
# (hidden_states,),
|
||||
# torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
# use_tracing=False,
|
||||
# verbose=False,
|
||||
# )
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
# f_ = open(mlir_path, "wb")
|
||||
# f_.write(bytecode)
|
||||
# f_.close()
|
||||
filepath = Path("lmhead.mlir")
|
||||
download_public_file(
|
||||
"gs://shark_tank/elias/compressed_sv/lmhead.mlir",
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
module = torch_mlir.compile(
|
||||
lmh,
|
||||
(hidden_states,),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
f_ = open(f"lmhead.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
@@ -699,21 +592,18 @@ class ShardedVicuna(VicunaBase):
|
||||
hidden_states, dynamic_axes=[1]
|
||||
)
|
||||
|
||||
# module = torch_mlir.compile(
|
||||
# fvn,
|
||||
# (hidden_states,),
|
||||
# torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
# use_tracing=False,
|
||||
# verbose=False,
|
||||
# )
|
||||
filepath = Path("norm.mlir")
|
||||
download_public_file(
|
||||
"gs://shark_tank/elias/compressed_sv/norm.mlir",
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
module = torch_mlir.compile(
|
||||
fvn,
|
||||
(hidden_states,),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
f_ = open(f"norm.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
@@ -744,27 +634,18 @@ class ShardedVicuna(VicunaBase):
|
||||
input_ids = torch_mlir.TensorPlaceholder.like(
|
||||
input_ids, dynamic_axes=[1]
|
||||
)
|
||||
# module = torch_mlir.compile(
|
||||
# fve,
|
||||
# (input_ids,),
|
||||
# torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
# use_tracing=False,
|
||||
# verbose=False,
|
||||
# )
|
||||
# bytecode_stream = BytesIO()
|
||||
# module.operation.write_bytecode(bytecode_stream)
|
||||
# bytecode = bytecode_stream.getvalue()
|
||||
# f_ = open(mlir_path, "wb")
|
||||
# f_.write(bytecode)
|
||||
# f_.close()
|
||||
filepath = Path("embedding.mlir")
|
||||
download_public_file(
|
||||
"gs://shark_tank/elias/compressed_sv/embedding.mlir",
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
module = torch_mlir.compile(
|
||||
fve,
|
||||
(input_ids,),
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
f_ = open(f"embedding.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
f_.close()
|
||||
|
||||
shark_module = SharkInference(
|
||||
@@ -838,7 +719,7 @@ class ShardedVicuna(VicunaBase):
|
||||
inputs0[2],
|
||||
),
|
||||
output_type="torch",
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
@@ -882,7 +763,7 @@ class ShardedVicuna(VicunaBase):
|
||||
pkv1_placeholder,
|
||||
),
|
||||
output_type="torch",
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
@@ -945,87 +826,17 @@ class ShardedVicuna(VicunaBase):
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
] + self.extra_args,
|
||||
],
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
return mlirs, modules
|
||||
|
||||
def compile_to_vmfb_one_model4(
|
||||
self, inputs0, layers0, inputs1, layers1, device="cpu"
|
||||
):
|
||||
mlirs, modules = [], []
|
||||
assert len(layers0) == len(layers1)
|
||||
for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))):
|
||||
mlir_path = Path(f"{idx}_full.mlir")
|
||||
vmfb_path = Path(f"{idx}_full.vmfb")
|
||||
# if vmfb_path.exists():
|
||||
# continue
|
||||
if mlir_path.exists():
|
||||
# print(f"Found layer {idx} mlir")
|
||||
f_ = open(mlir_path, "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
else:
|
||||
filepath = Path(f"{idx}_full.mlir")
|
||||
download_public_file(
|
||||
f"gs://shark_tank/elias/compressed_sv/{idx}_full.mlir",
|
||||
filepath.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
|
||||
f_ = open(f"{idx}_full.mlir", "rb")
|
||||
bytecode = f_.read()
|
||||
f_.close()
|
||||
mlirs.append(bytecode)
|
||||
|
||||
if vmfb_path.exists():
|
||||
# print(f"Found layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"first_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
None,
|
||||
device=device,
|
||||
device_idx=0,
|
||||
mlir_dialect="tm_tensor",
|
||||
mmap=True,
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
else:
|
||||
print(f"Compiling layer {idx} vmfb")
|
||||
device_idx = self.get_device_index(
|
||||
f"first_vicuna.model.model.layers.{idx}[\s.$]"
|
||||
)
|
||||
module = SharkInference(
|
||||
mlirs[idx],
|
||||
device=device,
|
||||
device_idx=0,
|
||||
mlir_dialect="tm_tensor",
|
||||
mmap=True,
|
||||
)
|
||||
module.save_module(
|
||||
module_name=f"{idx}_full",
|
||||
extra_args=[
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
] + self.extra_args,
|
||||
)
|
||||
module.load_module(vmfb_path)
|
||||
modules.append(module)
|
||||
return mlirs, modules
|
||||
|
||||
def get_sharded_model(self, device="cpu", compressed=False):
|
||||
def get_sharded_model(self, device="cpu"):
|
||||
# SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess
|
||||
# please don't change it
|
||||
SAMPLE_INPUT_LEN = 137
|
||||
vicuna_model = self.get_src_model()
|
||||
if compressed:
|
||||
vicuna_model.model = LlamaModel.from_pretrained(
|
||||
"TheBloke/vicuna-7B-1.1-HF"
|
||||
)
|
||||
|
||||
if self.precision in ["int4", "int8"]:
|
||||
print("Applying weight quantization..")
|
||||
@@ -1033,38 +844,16 @@ class ShardedVicuna(VicunaBase):
|
||||
quantize_model(
|
||||
get_model_impl(vicuna_model).layers,
|
||||
dtype=torch.float32,
|
||||
weight_quant_type="asym",
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=self.weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
input_bit_width=None,
|
||||
input_scale_type="float",
|
||||
input_param_method="stats",
|
||||
input_quant_type="asym",
|
||||
input_quant_granularity="per_tensor",
|
||||
quantize_input_zero_point=False,
|
||||
seqlen=2048,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
placeholder_pkv_segment = tuple(
|
||||
(
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
for _ in range(8)
|
||||
)
|
||||
placeholder_pkv_full = tuple(
|
||||
(
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]),
|
||||
)
|
||||
for _ in range(32)
|
||||
)
|
||||
|
||||
placeholder_input0 = (
|
||||
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
|
||||
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
|
||||
@@ -1115,39 +904,20 @@ class ShardedVicuna(VicunaBase):
|
||||
device_idx=device_idx,
|
||||
)
|
||||
|
||||
if not compressed:
|
||||
layers0 = [
|
||||
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
layers1 = [
|
||||
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
|
||||
else:
|
||||
layers00 = EightLayerLayerFV(vicuna_model.model.layers[0:8])
|
||||
layers01 = EightLayerLayerFV(vicuna_model.model.layers[8:16])
|
||||
layers02 = EightLayerLayerFV(vicuna_model.model.layers[16:24])
|
||||
layers03 = EightLayerLayerFV(vicuna_model.model.layers[24:32])
|
||||
layers10 = EightLayerLayerSV(vicuna_model.model.layers[0:8])
|
||||
layers11 = EightLayerLayerSV(vicuna_model.model.layers[8:16])
|
||||
layers12 = EightLayerLayerSV(vicuna_model.model.layers[16:24])
|
||||
layers13 = EightLayerLayerSV(vicuna_model.model.layers[24:32])
|
||||
layers0 = [layers00, layers01, layers02, layers03]
|
||||
layers1 = [layers10, layers11, layers12, layers13]
|
||||
|
||||
_, modules = self.compile_to_vmfb_one_model4(
|
||||
layers0 = [
|
||||
FirstVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
layers1 = [
|
||||
SecondVicunaLayer(layer) for layer in vicuna_model.model.layers
|
||||
]
|
||||
_, modules = self.compile_to_vmfb_one_model(
|
||||
placeholder_input0,
|
||||
layers0,
|
||||
placeholder_input1,
|
||||
layers1,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if not compressed:
|
||||
shark_layers = [CompiledVicunaLayer(m) for m in modules]
|
||||
else:
|
||||
shark_layers = [CompiledEightLayerLayer(m) for m in modules]
|
||||
vicuna_model.model.compressedlayers = shark_layers
|
||||
shark_layers = [CompiledVicunaLayer(m) for m in modules]
|
||||
|
||||
sharded_model = ShardedVicunaModel(
|
||||
vicuna_model,
|
||||
@@ -1159,18 +929,11 @@ class ShardedVicuna(VicunaBase):
|
||||
return sharded_model
|
||||
|
||||
def compile(self, device="cpu"):
|
||||
return self.get_sharded_model(
|
||||
device=device, compressed=self.compressed
|
||||
)
|
||||
return self.get_sharded_model(
|
||||
device=device, compressed=self.compressed
|
||||
)
|
||||
return self.get_sharded_model(device=device)
|
||||
|
||||
def generate(self, prompt, cli=False):
|
||||
def generate(self, prompt, cli=True):
|
||||
# TODO: refactor for cleaner integration
|
||||
|
||||
history = []
|
||||
|
||||
tokens_generated = []
|
||||
_past_key_values = None
|
||||
_token = None
|
||||
@@ -1188,8 +951,6 @@ class ShardedVicuna(VicunaBase):
|
||||
_token = generated_token_op["token"]
|
||||
_past_key_values = generated_token_op["past_key_values"]
|
||||
_detok = generated_token_op["detok"]
|
||||
history.append(_token)
|
||||
yield self.tokenizer.decode(history)
|
||||
|
||||
if _token == 2:
|
||||
break
|
||||
@@ -1200,7 +961,7 @@ class ShardedVicuna(VicunaBase):
|
||||
if type(tokens_generated[i]) != int:
|
||||
tokens_generated[i] = int(tokens_generated[i][0])
|
||||
result_output = self.tokenizer.decode(tokens_generated)
|
||||
yield result_output
|
||||
return result_output
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
@@ -1223,9 +984,8 @@ class UnshardedVicuna(VicunaBase):
|
||||
weight_group_size=128,
|
||||
download_vmfb=False,
|
||||
cache_vicunas=False,
|
||||
extra_args_cmd=[],
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens, extra_args_cmd=extra_args_cmd)
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
if "llama2" in self.model_name and hf_auth_token == None:
|
||||
raise ValueError(
|
||||
"HF auth token required. Pass it using --hf_auth_token flag."
|
||||
@@ -1422,10 +1182,11 @@ class UnshardedVicuna(VicunaBase):
|
||||
else:
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
|
||||
if Path(f"first_{self.precision}.mlir").exists():
|
||||
|
||||
if Path(f'first_{self.precision}.mlir').exists():
|
||||
print(f"loading first_{self.precision}.mlir")
|
||||
with open(Path(f"first_{self.precision}.mlir"), "r") as f:
|
||||
first_module = f.read()
|
||||
first_module = f.read()
|
||||
else:
|
||||
# generate first vicuna
|
||||
compilation_input_ids = self.tokenizer(
|
||||
@@ -1469,7 +1230,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
[*firstVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=[
|
||||
"quant.matmul_rhs_group_quant"
|
||||
"brevitas.matmul_rhs_group_quant"
|
||||
],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
@@ -1490,9 +1251,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del firstVicunaCompileInput
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
"[DEBUG] successfully generated first vicuna linalg mlir"
|
||||
)
|
||||
@@ -1556,7 +1314,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
[*secondVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=[
|
||||
"quant.matmul_rhs_group_quant"
|
||||
"brevitas.matmul_rhs_group_quant"
|
||||
],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
@@ -1577,8 +1335,6 @@ class UnshardedVicuna(VicunaBase):
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del secondVicunaCompileInput
|
||||
gc.collect()
|
||||
print(
|
||||
"[DEBUG] successfully generated second vicuna linalg mlir"
|
||||
)
|
||||
@@ -1606,7 +1362,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
] + self.extra_args,
|
||||
],
|
||||
)
|
||||
print("Saved vic vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
@@ -1623,22 +1379,23 @@ class UnshardedVicuna(VicunaBase):
|
||||
)
|
||||
return res_str
|
||||
|
||||
def generate(self, prompt, cli):
|
||||
def generate(self, prompt, cli=True):
|
||||
# TODO: refactor for cleaner integration
|
||||
import gc
|
||||
if self.shark_model is None:
|
||||
self.compile()
|
||||
res_tokens = []
|
||||
params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
|
||||
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False, cli=cli
|
||||
params=params, sharded=False
|
||||
)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
pkv = generated_token_op["past_key_values"]
|
||||
detok = generated_token_op["detok"]
|
||||
yield detok, ""
|
||||
yield detok
|
||||
|
||||
res_tokens.append(token)
|
||||
if cli:
|
||||
@@ -1654,7 +1411,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False, cli=cli
|
||||
params=params, sharded=False
|
||||
)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
@@ -1671,89 +1428,23 @@ class UnshardedVicuna(VicunaBase):
|
||||
else:
|
||||
if cli:
|
||||
print(f"{detok}", end=" ", flush=True)
|
||||
yield detok, ""
|
||||
|
||||
if len(res_tokens) % 3 == 0:
|
||||
part_str = self.decode_tokens(res_tokens)
|
||||
yield part_str
|
||||
|
||||
res_str = self.decode_tokens(res_tokens)
|
||||
# print(f"[DEBUG] final output : \n{res_str}")
|
||||
yield res_str, "formatted"
|
||||
yield res_str
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"llama2_7b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"llama2_70b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
),
|
||||
"StableLM": (
|
||||
"<|SYSTEM|># StableLM Tuned (Alpha version)"
|
||||
"\n- StableLM is a helpful and harmless open-source AI language model "
|
||||
"developed by StabilityAI."
|
||||
"\n- StableLM is excited to be able to help the user, but will refuse "
|
||||
"to do anything that could be considered harmful to the user."
|
||||
"\n- StableLM is more than just an information source, StableLM is also "
|
||||
"able to write poetry, short stories, and make jokes."
|
||||
"\n- StableLM will refuse to participate in anything that "
|
||||
"could harm a human."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna4": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna1p3": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"codegen": "",
|
||||
}
|
||||
|
||||
def create_prompt(model_name, history):
|
||||
global start_message
|
||||
system_message = start_message[model_name]
|
||||
conversation = "".join(
|
||||
[
|
||||
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
return msg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
if args.iree_vulkan_target_triple != "":
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
vic = None
|
||||
if not args.sharded:
|
||||
vic_mlir_path = (
|
||||
@@ -1777,7 +1468,6 @@ if __name__ == "__main__":
|
||||
weight_group_size=args.weight_group_size,
|
||||
download_vmfb=args.download_vmfb,
|
||||
cache_vicunas=args.cache_vicunas,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
else:
|
||||
if args.config is not None:
|
||||
@@ -1792,7 +1482,6 @@ if __name__ == "__main__":
|
||||
precision=args.precision,
|
||||
config_json=config_json,
|
||||
weight_group_size=args.weight_group_size,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
if args.model_name == "vicuna":
|
||||
system_message = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
|
||||
@@ -1806,7 +1495,10 @@ if __name__ == "__main__":
|
||||
answer to a question, please don't share false information."""
|
||||
prologue_prompt = "ASSISTANT:\n"
|
||||
|
||||
from apps.stable_diffusion.web.ui.stablelm_ui import chat, set_vicuna_model
|
||||
|
||||
history = []
|
||||
set_vicuna_model(vic)
|
||||
|
||||
model_list = {
|
||||
"vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
|
||||
@@ -1817,8 +1509,13 @@ if __name__ == "__main__":
|
||||
# TODO: Add break condition from user input
|
||||
user_prompt = input("User: ")
|
||||
history.append([user_prompt, ""])
|
||||
prompt = create_prompt(args.model_name, history)
|
||||
for text, msg in vic.generate(prompt, cli=True):
|
||||
if "formatted" in msg:
|
||||
print("Response:",text)
|
||||
history[-1][1] = text
|
||||
history = list(
|
||||
chat(
|
||||
system_message,
|
||||
history,
|
||||
model=model_list[args.model_name],
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
cli=args.cli,
|
||||
)
|
||||
)[0]
|
||||
|
||||
@@ -1,879 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import iree.runtime
|
||||
import itertools
|
||||
import subprocess
|
||||
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir import TensorPlaceholder
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForCausalLM,
|
||||
LlamaPreTrainedModel,
|
||||
)
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward,
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
|
||||
FirstVicunaLayer,
|
||||
SecondVicunaLayer,
|
||||
CompiledVicunaLayer,
|
||||
ShardedVicunaModel,
|
||||
LMHead,
|
||||
LMHeadCompiled,
|
||||
VicunaEmbedding,
|
||||
VicunaEmbeddingCompiled,
|
||||
VicunaNorm,
|
||||
VicunaNormCompiled,
|
||||
)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna,
|
||||
)
|
||||
from apps.language_models.utils import (
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import get_f16_inputs
|
||||
from shark.shark_importer import import_with_fx
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer,
|
||||
LlamaRMSNorm,
|
||||
_make_causal_mask,
|
||||
_expand_mask,
|
||||
)
|
||||
from torch import nn
|
||||
from time import time
|
||||
|
||||
|
||||
class LlamaModel(LlamaPreTrainedModel):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
|
||||
|
||||
Args:
|
||||
config: LlamaConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size, config.hidden_size, self.padding_idx
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
LlamaDecoderLayer(config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
):
|
||||
# create causal mask
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape,
|
||||
inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
expanded_attn_mask = _expand_mask(
|
||||
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
||||
).to(inputs_embeds.device)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask
|
||||
if combined_attention_mask is None
|
||||
else expanded_attn_mask + combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
t1 = time()
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = (
|
||||
use_cache if use_cache is not None else self.config.use_cache
|
||||
)
|
||||
|
||||
return_dict = (
|
||||
return_dict
|
||||
if return_dict is not None
|
||||
else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = (
|
||||
seq_length_with_past + past_key_values_length
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
device = (
|
||||
input_ids.device
|
||||
if input_ids is not None
|
||||
else inputs_embeds.device
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.compressedlayers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (
|
||||
past_key_values[8 * idx : 8 * (idx + 1)]
|
||||
if past_key_values is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer.forward(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[1:],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
try:
|
||||
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
|
||||
except:
|
||||
_ = 10
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
next_cache = tuple(itertools.chain.from_iterable(next_cache))
|
||||
print(f"Token generated in {time() - t1} seconds")
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class EightLayerLayerSV(torch.nn.Module):
|
||||
def __init__(self, layers):
|
||||
super().__init__()
|
||||
assert len(layers) == 8
|
||||
self.layers = layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv00,
|
||||
pkv01,
|
||||
pkv10,
|
||||
pkv11,
|
||||
pkv20,
|
||||
pkv21,
|
||||
pkv30,
|
||||
pkv31,
|
||||
pkv40,
|
||||
pkv41,
|
||||
pkv50,
|
||||
pkv51,
|
||||
pkv60,
|
||||
pkv61,
|
||||
pkv70,
|
||||
pkv71,
|
||||
):
|
||||
pkvs = [
|
||||
(pkv00, pkv01),
|
||||
(pkv10, pkv11),
|
||||
(pkv20, pkv21),
|
||||
(pkv30, pkv31),
|
||||
(pkv40, pkv41),
|
||||
(pkv50, pkv51),
|
||||
(pkv60, pkv61),
|
||||
(pkv70, pkv71),
|
||||
]
|
||||
new_pkvs = []
|
||||
for layer, pkv in zip(self.layers, pkvs):
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=(
|
||||
pkv[0],
|
||||
pkv[1],
|
||||
),
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
return (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
|
||||
|
||||
class EightLayerLayerFV(torch.nn.Module):
|
||||
def __init__(self, layers):
|
||||
super().__init__()
|
||||
assert len(layers) == 8
|
||||
self.layers = layers
|
||||
|
||||
def forward(self, hidden_states, attention_mask, position_ids):
|
||||
new_pkvs = []
|
||||
for layer in self.layers:
|
||||
outputs = layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=None,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
return (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
|
||||
|
||||
class CompiledEightLayerLayerSV(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
(
|
||||
(pkv00, pkv01),
|
||||
(pkv10, pkv11),
|
||||
(pkv20, pkv21),
|
||||
(pkv30, pkv31),
|
||||
(pkv40, pkv41),
|
||||
(pkv50, pkv51),
|
||||
(pkv60, pkv61),
|
||||
(pkv70, pkv71),
|
||||
) = past_key_value
|
||||
pkv00 = pkv00.detatch()
|
||||
pkv01 = pkv01.detatch()
|
||||
pkv10 = pkv10.detatch()
|
||||
pkv11 = pkv11.detatch()
|
||||
pkv20 = pkv20.detatch()
|
||||
pkv21 = pkv21.detatch()
|
||||
pkv30 = pkv30.detatch()
|
||||
pkv31 = pkv31.detatch()
|
||||
pkv40 = pkv40.detatch()
|
||||
pkv41 = pkv41.detatch()
|
||||
pkv50 = pkv50.detatch()
|
||||
pkv51 = pkv51.detatch()
|
||||
pkv60 = pkv60.detatch()
|
||||
pkv61 = pkv61.detatch()
|
||||
pkv70 = pkv70.detatch()
|
||||
pkv71 = pkv71.detatch()
|
||||
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv00,
|
||||
pkv01,
|
||||
pkv10,
|
||||
pkv11,
|
||||
pkv20,
|
||||
pkv21,
|
||||
pkv30,
|
||||
pkv31,
|
||||
pkv40,
|
||||
pkv41,
|
||||
pkv50,
|
||||
pkv51,
|
||||
pkv60,
|
||||
pkv61,
|
||||
pkv70,
|
||||
pkv71,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
return (
|
||||
output[0],
|
||||
(output[1][0], output[1][1]),
|
||||
(output[2][0], output[2][1]),
|
||||
(output[3][0], output[3][1]),
|
||||
(output[4][0], output[4][1]),
|
||||
(output[5][0], output[5][1]),
|
||||
(output[6][0], output[6][1]),
|
||||
(output[7][0], output[7][1]),
|
||||
(output[8][0], output[8][1]),
|
||||
)
|
||||
|
||||
|
||||
def forward_compressed(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
device = (
|
||||
input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
)
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.compressedlayers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = (
|
||||
past_key_values[8 * idx : 8 * (idx + 1)]
|
||||
if past_key_values is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (
|
||||
layer_outputs[2 if output_attentions else 1],
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [
|
||||
hidden_states,
|
||||
next_cache,
|
||||
all_hidden_states,
|
||||
all_self_attns,
|
||||
]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class CompiledEightLayerLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
t2 = time()
|
||||
if past_key_value is None:
|
||||
try:
|
||||
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
|
||||
except:
|
||||
pass
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
t1 = time()
|
||||
|
||||
output = self.model(
|
||||
"first_vicuna_forward",
|
||||
(hidden_states, attention_mask, position_ids),
|
||||
send_to_host=False,
|
||||
)
|
||||
output2 = (
|
||||
output[0],
|
||||
(
|
||||
output[1],
|
||||
output[2],
|
||||
),
|
||||
(
|
||||
output[3],
|
||||
output[4],
|
||||
),
|
||||
(
|
||||
output[5],
|
||||
output[6],
|
||||
),
|
||||
(
|
||||
output[7],
|
||||
output[8],
|
||||
),
|
||||
(
|
||||
output[9],
|
||||
output[10],
|
||||
),
|
||||
(
|
||||
output[11],
|
||||
output[12],
|
||||
),
|
||||
(
|
||||
output[13],
|
||||
output[14],
|
||||
),
|
||||
(
|
||||
output[15],
|
||||
output[16],
|
||||
),
|
||||
)
|
||||
return output2
|
||||
else:
|
||||
(
|
||||
(pkv00, pkv01),
|
||||
(pkv10, pkv11),
|
||||
(pkv20, pkv21),
|
||||
(pkv30, pkv31),
|
||||
(pkv40, pkv41),
|
||||
(pkv50, pkv51),
|
||||
(pkv60, pkv61),
|
||||
(pkv70, pkv71),
|
||||
) = past_key_value
|
||||
|
||||
try:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv00 = pkv00.detach()
|
||||
pkv01 = pkv01.detach()
|
||||
pkv10 = pkv10.detach()
|
||||
pkv11 = pkv11.detach()
|
||||
pkv20 = pkv20.detach()
|
||||
pkv21 = pkv21.detach()
|
||||
pkv30 = pkv30.detach()
|
||||
pkv31 = pkv31.detach()
|
||||
pkv40 = pkv40.detach()
|
||||
pkv41 = pkv41.detach()
|
||||
pkv50 = pkv50.detach()
|
||||
pkv51 = pkv51.detach()
|
||||
pkv60 = pkv60.detach()
|
||||
pkv61 = pkv61.detach()
|
||||
pkv70 = pkv70.detach()
|
||||
pkv71 = pkv71.detach()
|
||||
except:
|
||||
x = 10
|
||||
|
||||
t1 = time()
|
||||
if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
|
||||
hidden_states = np.array(hidden_states, hidden_states.dtype)
|
||||
hidden_states = torch.tensor(hidden_states)
|
||||
hidden_states = hidden_states.detach()
|
||||
|
||||
output = self.model(
|
||||
"second_vicuna_forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
pkv00,
|
||||
pkv01,
|
||||
pkv10,
|
||||
pkv11,
|
||||
pkv20,
|
||||
pkv21,
|
||||
pkv30,
|
||||
pkv31,
|
||||
pkv40,
|
||||
pkv41,
|
||||
pkv50,
|
||||
pkv51,
|
||||
pkv60,
|
||||
pkv61,
|
||||
pkv70,
|
||||
pkv71,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
print(f"{time() - t1}")
|
||||
del pkv00
|
||||
del pkv01
|
||||
del pkv10
|
||||
del pkv11
|
||||
del pkv20
|
||||
del pkv21
|
||||
del pkv30
|
||||
del pkv31
|
||||
del pkv40
|
||||
del pkv41
|
||||
del pkv50
|
||||
del pkv51
|
||||
del pkv60
|
||||
del pkv61
|
||||
del pkv70
|
||||
del pkv71
|
||||
output2 = (
|
||||
output[0],
|
||||
(
|
||||
output[1],
|
||||
output[2],
|
||||
),
|
||||
(
|
||||
output[3],
|
||||
output[4],
|
||||
),
|
||||
(
|
||||
output[5],
|
||||
output[6],
|
||||
),
|
||||
(
|
||||
output[7],
|
||||
output[8],
|
||||
),
|
||||
(
|
||||
output[9],
|
||||
output[10],
|
||||
),
|
||||
(
|
||||
output[11],
|
||||
output[12],
|
||||
),
|
||||
(
|
||||
output[13],
|
||||
output[14],
|
||||
),
|
||||
(
|
||||
output[15],
|
||||
output[16],
|
||||
),
|
||||
)
|
||||
return output2
|
||||
@@ -301,13 +301,12 @@ class CombinedModel(torch.nn.Module):
|
||||
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
|
||||
|
||||
def forward(self, input_ids):
|
||||
first_output = self.first_vicuna(input_ids=input_ids)
|
||||
# generate second vicuna
|
||||
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
|
||||
for _ in range(64)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
second_output = self.second_vicuna(*secondVicunaCompileInput)
|
||||
first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
|
||||
logits = first_output[0]
|
||||
pkv = first_output[1:]
|
||||
|
||||
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
secondVicunaInput = (token,) + tuple(pkv)
|
||||
second_output = self.second_vicuna(secondVicunaInput)
|
||||
return second_output
|
||||
|
||||
@@ -66,7 +66,7 @@ class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers, lmhead, embedding, norm):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
# assert len(layers) == len(model.model.layers)
|
||||
assert len(layers) == len(model.model.layers)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers = layers
|
||||
@@ -132,10 +132,7 @@ class VicunaNormCompiled(torch.nn.Module):
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
try:
|
||||
hidden_states.detach()
|
||||
except:
|
||||
pass
|
||||
hidden_states.detach()
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = torch.tensor(output)
|
||||
return output
|
||||
|
||||
@@ -136,8 +136,7 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
@@ -146,21 +145,20 @@ def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_s
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
quant〇matmul_rhs_group_quant〡shape,
|
||||
quant〇matmul_rhs_group_quant〡dtype,
|
||||
quant〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
|
||||
|
||||
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
@@ -211,7 +209,7 @@ def compile_int_precision(
|
||||
torchscript_module,
|
||||
inputs,
|
||||
output_type="torch",
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
|
||||
@@ -34,7 +34,7 @@ from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
|
||||
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
@@ -287,7 +287,7 @@ def lora_train(
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,6 @@ datas += copy_metadata("safetensors")
|
||||
datas += copy_metadata("Pillow")
|
||||
datas += copy_metadata("sentencepiece")
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += copy_metadata("huggingface-hub")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
datas += collect_data_files("accelerate")
|
||||
@@ -43,7 +42,7 @@ datas += collect_data_files("gradio")
|
||||
datas += collect_data_files("gradio_client")
|
||||
datas += collect_data_files("iree")
|
||||
datas += collect_data_files("google_cloud_storage")
|
||||
datas += collect_data_files("shark", include_py_files=True)
|
||||
datas += collect_data_files("shark")
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
@@ -51,7 +50,6 @@ datas += collect_data_files("sentencepiece")
|
||||
datas += collect_data_files("jsonschema")
|
||||
datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
datas += [
|
||||
("src/utils/resources/prompts.json", "resources"),
|
||||
("src/utils/resources/model_db.json", "resources"),
|
||||
@@ -77,4 +75,3 @@ hiddenimports += [
|
||||
x for x in collect_submodules("transformers") if "tests" not in x
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
hiddenimports += ["iree._runtime", "iree._runtime_libs"]
|
||||
|
||||
@@ -109,7 +109,7 @@ def load_lower_configs(base_model_id=None):
|
||||
spec = spec.split("-")[0]
|
||||
|
||||
if args.annotation_model == "vae":
|
||||
if not spec or spec in ["sm_80"]:
|
||||
if not spec or spec in ["rdna3", "sm_80"]:
|
||||
config_name = (
|
||||
f"{args.annotation_model}_{args.precision}_{device}.json"
|
||||
)
|
||||
@@ -281,13 +281,9 @@ def sd_model_annotation(mlir_model, model_name, base_model_id=None):
|
||||
if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]:
|
||||
use_winograd = True
|
||||
winograd_config_dir = load_winograd_configs()
|
||||
winograd_model = annotate_with_winograd(
|
||||
tuned_model = annotate_with_winograd(
|
||||
mlir_model, winograd_config_dir, model_name
|
||||
)
|
||||
lowering_config_dir = load_lower_configs(base_model_id)
|
||||
tuned_model = annotate_with_lower_configs(
|
||||
winograd_model, lowering_config_dir, model_name, use_winograd
|
||||
)
|
||||
else:
|
||||
tuned_model = mlir_model
|
||||
else:
|
||||
|
||||
@@ -519,12 +519,6 @@ p.add_argument(
|
||||
"in shark importer. Does nothing if import_mlir is false (the default).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--iree_constant_folding",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Controls constant folding in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Web UI flags
|
||||
|
||||
@@ -500,12 +500,6 @@ def get_opt_flags(model, precision="fp16"):
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
if args.iree_constant_folding == False:
|
||||
iree_flags.append("--iree-opt-const-expr-hoisting=False")
|
||||
iree_flags.append(
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
|
||||
# Disable bindings fusion to work with moltenVK.
|
||||
if sys.platform == "darwin":
|
||||
iree_flags.append("-iree-stream-fuse-binding=false")
|
||||
|
||||
@@ -37,7 +37,7 @@ def launch_app(address):
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False, storage_path=os.getcwd())
|
||||
webview.start(private_mode=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -115,8 +115,7 @@ if __name__ == "__main__":
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
# h2ogpt_upload,
|
||||
# h2ogpt_web,
|
||||
h2ogpt_web,
|
||||
img2img_web,
|
||||
img2img_custom_model,
|
||||
img2img_hf_model_id,
|
||||
@@ -155,7 +154,6 @@ if __name__ == "__main__":
|
||||
upscaler_sendto_outpaint,
|
||||
lora_train_web,
|
||||
model_web,
|
||||
model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
@@ -213,15 +211,6 @@ if __name__ == "__main__":
|
||||
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
|
||||
) as sd_web:
|
||||
with gr.Tabs() as tabs:
|
||||
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
|
||||
# have a unique id that doesn't clash with any of the other tabs,
|
||||
# and that the order in the code here is the order they should
|
||||
# appear in the ui, as the id value doesn't determine the order.
|
||||
|
||||
# Where possible, avoid changing the id of any tab that is the
|
||||
# destination of one of the 'send to' buttons. If you do have to change
|
||||
# that id, make sure you update the relevant register_button_click calls
|
||||
# further down with the new id.
|
||||
with gr.TabItem(label="Text-to-Image", id=0):
|
||||
txt2img_web.render()
|
||||
with gr.TabItem(label="Image-to-Image", id=1):
|
||||
@@ -249,20 +238,14 @@ if __name__ == "__main__":
|
||||
)
|
||||
with gr.TabItem(label="Model Manager", id=6):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
with gr.TabItem(label="LoRA Training (Experimental)", id=8):
|
||||
lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot (Experimental)", id=8):
|
||||
with gr.TabItem(label="Chat Bot (Experimental)", id=7):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(
|
||||
label="Generate Sharding Config (Experimental)", id=9
|
||||
):
|
||||
model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=9):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
# h2ogpt_upload.render()
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
with gr.TabItem(label="DocuChat(Experimental)", id=10):
|
||||
h2ogpt_web.render()
|
||||
|
||||
# send to buttons
|
||||
register_button_click(
|
||||
|
||||
@@ -78,7 +78,7 @@ from apps.stable_diffusion.web.ui.stablelm_ui import (
|
||||
stablelm_chat,
|
||||
llm_chat_api,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.generate_config import model_config_web
|
||||
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
|
||||
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
|
||||
from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_web,
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
import gradio as gr
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
|
||||
from shark.shark_generate_model_config import GenerateConfigFile
|
||||
|
||||
|
||||
def get_model_config():
|
||||
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
|
||||
[1, 19]
|
||||
)
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
|
||||
model = CombinedModel()
|
||||
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
|
||||
return c.split_into_layers()
|
||||
|
||||
|
||||
with gr.Blocks() as model_config_web:
|
||||
with gr.Row():
|
||||
hf_models = gr.Dropdown(
|
||||
label="Model List",
|
||||
choices=["Vicuna"],
|
||||
value="Vicuna",
|
||||
visible=True,
|
||||
)
|
||||
get_model_config_btn = gr.Button(value="Get Model Config")
|
||||
json_view = gr.JSON()
|
||||
|
||||
get_model_config_btn.click(
|
||||
fn=get_model_config,
|
||||
inputs=[],
|
||||
outputs=[json_view],
|
||||
)
|
||||
@@ -12,10 +12,6 @@ from apps.language_models.langchain.enums import (
|
||||
LangChainAction,
|
||||
)
|
||||
import apps.language_models.langchain.gen as gen
|
||||
from gpt_langchain import (
|
||||
path_to_docs,
|
||||
create_or_update_db,
|
||||
)
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
|
||||
@@ -37,15 +33,8 @@ start_message = """
|
||||
|
||||
def create_prompt(history):
|
||||
system_message = start_message
|
||||
for item in history:
|
||||
print("His item: ", item)
|
||||
|
||||
conversation = "<|endoftext|>".join(
|
||||
[
|
||||
"<|endoftext|><|answer|>".join([item[0], item[1]])
|
||||
for item in history
|
||||
]
|
||||
)
|
||||
conversation = "".join(["".join([item[0], item[1]]) for item in history])
|
||||
|
||||
msg = system_message + conversation
|
||||
msg = msg.strip()
|
||||
@@ -55,12 +44,10 @@ def create_prompt(history):
|
||||
def chat(curr_system_message, history, device, precision):
|
||||
args.run_docuchat_web = True
|
||||
global h2ogpt_model
|
||||
global sharkModel
|
||||
global h2ogpt_tokenizer
|
||||
global model_state
|
||||
global langchain
|
||||
global userpath_selector
|
||||
from apps.language_models.langchain.h2oai_pipeline import generate_token
|
||||
|
||||
if h2ogpt_model == 0:
|
||||
if "cuda" in device:
|
||||
@@ -115,14 +102,9 @@ def chat(curr_system_message, history, device, precision):
|
||||
prompt_type=None,
|
||||
prompt_dict=None,
|
||||
)
|
||||
from apps.language_models.langchain.h2oai_pipeline import (
|
||||
H2OGPTSHARKModel,
|
||||
)
|
||||
|
||||
sharkModel = H2OGPTSHARKModel()
|
||||
|
||||
prompt = create_prompt(history)
|
||||
output_dict = langchain.evaluate(
|
||||
output = langchain.evaluate(
|
||||
model_state=model_state,
|
||||
my_db_state=None,
|
||||
instruction=prompt,
|
||||
@@ -182,22 +164,14 @@ def chat(curr_system_message, history, device, precision):
|
||||
model_lock=True,
|
||||
user_path=userpath_selector.value,
|
||||
)
|
||||
|
||||
output = generate_token(sharkModel, **output_dict)
|
||||
for partial_text in output:
|
||||
history[-1][1] = partial_text
|
||||
history[-1][1] = partial_text["response"]
|
||||
yield history
|
||||
|
||||
return history
|
||||
|
||||
|
||||
userpath_selector = gr.Textbox(
|
||||
label="Document Directory",
|
||||
value=str(os.path.abspath("apps/language_models/langchain/user_path/")),
|
||||
interactive=True,
|
||||
container=True,
|
||||
)
|
||||
|
||||
with gr.Blocks(title="DocuChat") as h2ogpt_web:
|
||||
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
|
||||
with gr.Row():
|
||||
supported_devices = available_devices
|
||||
enabled = len(supported_devices) > 0
|
||||
@@ -224,6 +198,14 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
userpath_selector = gr.Textbox(
|
||||
label="Document Directory",
|
||||
value=str(
|
||||
os.path.abspath("apps/language_models/langchain/user_path/")
|
||||
),
|
||||
interactive=True,
|
||||
container=True,
|
||||
)
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@@ -267,100 +249,3 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
|
||||
queue=False,
|
||||
)
|
||||
clear.click(lambda: None, None, [chatbot], queue=False)
|
||||
|
||||
|
||||
with gr.Blocks(title="DocuChat Upload") as h2ogpt_upload:
|
||||
import pathlib
|
||||
|
||||
upload_path = None
|
||||
database = None
|
||||
database_directory = os.path.abspath(
|
||||
"apps/language_models/langchain/db_path/"
|
||||
)
|
||||
|
||||
def read_path():
|
||||
global upload_path
|
||||
filenames = [
|
||||
[f]
|
||||
for f in os.listdir(upload_path)
|
||||
if os.path.isfile(os.path.join(upload_path, f))
|
||||
]
|
||||
filenames.sort()
|
||||
return filenames
|
||||
|
||||
def upload_file(f):
|
||||
names = []
|
||||
for tmpfile in f:
|
||||
name = tmpfile.name.split("/")[-1]
|
||||
basename = os.path.join(upload_path, name)
|
||||
with open(basename, "wb") as w:
|
||||
with open(tmpfile.name, "rb") as r:
|
||||
w.write(r.read())
|
||||
update_or_create_db()
|
||||
return read_path()
|
||||
|
||||
def update_userpath(newpath):
|
||||
global upload_path
|
||||
upload_path = newpath
|
||||
pathlib.Path(upload_path).mkdir(parents=True, exist_ok=True)
|
||||
return read_path()
|
||||
|
||||
def update_or_create_db():
|
||||
global database
|
||||
global upload_path
|
||||
|
||||
sources = path_to_docs(
|
||||
upload_path,
|
||||
verbose=True,
|
||||
fail_any_exception=False,
|
||||
n_jobs=-1,
|
||||
chunk=True,
|
||||
chunk_size=512,
|
||||
url=None,
|
||||
enable_captions=False,
|
||||
captions_model=None,
|
||||
caption_loader=None,
|
||||
enable_ocr=False,
|
||||
)
|
||||
|
||||
pathlib.Path(database_directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
database = create_or_update_db(
|
||||
"chroma",
|
||||
database_directory,
|
||||
"UserData",
|
||||
sources,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
"sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
|
||||
def first_run():
|
||||
global database
|
||||
if database is None:
|
||||
update_or_create_db()
|
||||
|
||||
update_userpath(
|
||||
os.path.abspath("apps/language_models/langchain/user_path/")
|
||||
)
|
||||
h2ogpt_upload.load(fn=first_run)
|
||||
h2ogpt_web.load(fn=first_run)
|
||||
|
||||
with gr.Column():
|
||||
text = gr.DataFrame(
|
||||
col_count=(1, "fixed"),
|
||||
type="array",
|
||||
label="Documents",
|
||||
value=read_path(),
|
||||
)
|
||||
with gr.Row():
|
||||
upload = gr.UploadButton(
|
||||
label="Upload documents",
|
||||
file_count="multiple",
|
||||
)
|
||||
upload.upload(fn=upload_file, inputs=upload, outputs=text)
|
||||
userpath_selector.render()
|
||||
userpath_selector.input(
|
||||
fn=update_userpath, inputs=userpath_selector, outputs=text
|
||||
).then(fn=update_or_create_db)
|
||||
|
||||
@@ -7,8 +7,6 @@ from transformers import (
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
def user(message, history):
|
||||
@@ -28,7 +26,6 @@ model_map = {
|
||||
"codegen": "Salesforce/codegen25-7b-multi",
|
||||
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
|
||||
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"vicuna4": "TheBloke/vicuna-7B-1.1-HF",
|
||||
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
|
||||
}
|
||||
|
||||
@@ -68,11 +65,6 @@ start_message = {
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna4": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
),
|
||||
"vicuna1p3": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
@@ -88,7 +80,6 @@ def create_prompt(model_name, history):
|
||||
if model_name in [
|
||||
"StableLM",
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"llama2_7b",
|
||||
"llama2_70b",
|
||||
@@ -114,144 +105,53 @@ def set_vicuna_model(model):
|
||||
vicuna_model = model
|
||||
|
||||
|
||||
def get_default_config():
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
|
||||
[1, 19]
|
||||
)
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
CombinedModel,
|
||||
)
|
||||
from shark.shark_generate_model_config import GenerateConfigFile
|
||||
|
||||
model = CombinedModel()
|
||||
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
|
||||
c.split_into_layers()
|
||||
|
||||
|
||||
model_vmfb_key = ""
|
||||
|
||||
|
||||
# TODO: Make chat reusable for UI and API
|
||||
def chat(
|
||||
curr_system_message,
|
||||
history,
|
||||
model,
|
||||
device,
|
||||
precision,
|
||||
config_file,
|
||||
cli=False,
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
def chat(curr_system_message, history, model, device, precision, cli=True):
|
||||
global past_key_values
|
||||
global model_vmfb_key
|
||||
|
||||
global vicuna_model
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
print("unrecognized device")
|
||||
|
||||
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{precision}"
|
||||
if model_name in [
|
||||
"vicuna",
|
||||
"vicuna4",
|
||||
"vicuna1p3",
|
||||
"codegen",
|
||||
"llama2_7b",
|
||||
"llama2_70b",
|
||||
]:
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
from apps.language_models.scripts.vicuna import (
|
||||
UnshardedVicuna,
|
||||
)
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
|
||||
# get iree flags that need to be overridden, from commandline args
|
||||
_extra_args = []
|
||||
# vulkan target triple
|
||||
if args.iree_vulkan_target_triple != "":
|
||||
_extra_args.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
if vicuna_model == 0:
|
||||
if "cuda" in device:
|
||||
device = "cuda"
|
||||
elif "sync" in device:
|
||||
device = "cpu-sync"
|
||||
elif "task" in device:
|
||||
device = "cpu-task"
|
||||
elif "vulkan" in device:
|
||||
device = "vulkan"
|
||||
else:
|
||||
# if config_file is None:
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
extra_args_cmd=_extra_args,
|
||||
)
|
||||
# else:
|
||||
# if config_file is not None:
|
||||
# config_file = open(config_file)
|
||||
# config_json = json.load(config_file)
|
||||
# config_file.close()
|
||||
# else:
|
||||
# config_json = get_default_config()
|
||||
# vicuna_model = ShardedVicuna(
|
||||
# model_name,
|
||||
# device=device,
|
||||
# precision=precision,
|
||||
# config_json=config_json,
|
||||
# )
|
||||
print("unrecognized device")
|
||||
|
||||
max_toks = 128 if model_name == "codegen" else 512
|
||||
vicuna_model = UnshardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
hf_auth_token=args.hf_auth_token,
|
||||
device=device,
|
||||
precision=precision,
|
||||
max_num_tokens=max_toks,
|
||||
)
|
||||
prompt = create_prompt(model_name, history)
|
||||
|
||||
partial_text = ""
|
||||
count = 0
|
||||
start_time = time.time()
|
||||
for text, msg in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
count += 1
|
||||
if "formatted" in msg:
|
||||
history[-1][1] = text
|
||||
end_time = time.time()
|
||||
tokens_per_sec = count / (end_time - start_time)
|
||||
yield history, str(
|
||||
format(tokens_per_sec, ".2f")
|
||||
) + " tokens/sec"
|
||||
else:
|
||||
partial_text += text + " "
|
||||
history[-1][1] = partial_text
|
||||
yield history, ""
|
||||
for partial_text in vicuna_model.generate(prompt, cli=cli):
|
||||
history[-1][1] = partial_text
|
||||
yield history
|
||||
|
||||
return history, ""
|
||||
return history
|
||||
|
||||
# else Model is StableLM
|
||||
global sharkModel
|
||||
@@ -259,8 +159,7 @@ def chat(
|
||||
SharkStableLM,
|
||||
)
|
||||
|
||||
if new_model_vmfb_key != model_vmfb_key:
|
||||
model_vmfb_key = new_model_vmfb_key
|
||||
if sharkModel == 0:
|
||||
# max_new_tokens=512
|
||||
shark_slm = SharkStableLM(
|
||||
model_name
|
||||
@@ -277,6 +176,7 @@ def chat(
|
||||
|
||||
partial_text = ""
|
||||
for new_text in words_list:
|
||||
print(new_text)
|
||||
partial_text += new_text
|
||||
history[-1][1] = partial_text
|
||||
# Yield an empty string to clean up the message textbox and the updated
|
||||
@@ -398,7 +298,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
)
|
||||
model = gr.Dropdown(
|
||||
label="Select Model",
|
||||
value=model_choices[4],
|
||||
value=model_choices[0],
|
||||
choices=model_choices,
|
||||
)
|
||||
supported_devices = available_devices
|
||||
@@ -406,35 +306,31 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
# print(supported_devices)
|
||||
devices = gr.Dropdown(
|
||||
print(supported_devices)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
# multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="int8",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
tokens_time = gr.Textbox(label="Tokens generated per second")
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Row():
|
||||
with gr.Group():
|
||||
config_file = gr.File(
|
||||
label="Upload sharding configuration", visible=False
|
||||
)
|
||||
json_view_button = gr.Button(label="View as JSON", visible=False)
|
||||
json_view = gr.JSON(interactive=True, visible=False)
|
||||
config_file = gr.File(label="Upload sharding configuration")
|
||||
json_view_button = gr.Button("View as JSON")
|
||||
json_view = gr.JSON()
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
@@ -461,16 +357,16 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot, tokens_time],
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
submit_click_event = submit.click(
|
||||
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
|
||||
).then(
|
||||
fn=chat,
|
||||
inputs=[system_msg, chatbot, model, devices, precision, config_file],
|
||||
outputs=[chatbot, tokens_time],
|
||||
inputs=[system_msg, chatbot, model, device, precision],
|
||||
outputs=[chatbot],
|
||||
queue=True,
|
||||
)
|
||||
stop.click(
|
||||
|
||||
@@ -24,13 +24,13 @@ def get_image(url, local_filename):
|
||||
shutil.copyfileobj(res.raw, f)
|
||||
|
||||
|
||||
def compare_images(new_filename, golden_filename, upload=False):
|
||||
def compare_images(new_filename, golden_filename):
|
||||
new = np.array(Image.open(new_filename)) / 255.0
|
||||
golden = np.array(Image.open(golden_filename)) / 255.0
|
||||
diff = np.abs(new - golden)
|
||||
mean = np.mean(diff)
|
||||
if mean > 0.1:
|
||||
if os.name != "nt" and upload == True:
|
||||
if os.name != "nt":
|
||||
subprocess.run(
|
||||
[
|
||||
"gsutil",
|
||||
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename, upload=False):
|
||||
"gs://shark_tank/testdata/builder/",
|
||||
]
|
||||
)
|
||||
raise AssertionError("new and golden not close")
|
||||
raise SystemExit("new and golden not close")
|
||||
else:
|
||||
print("SUCCESS")
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
IMPORTER=1 BENCHMARK=1 NO_BREVITAS=1 ./setup_venv.sh
|
||||
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
|
||||
source $GITHUB_WORKSPACE/shark.venv/bin/activate
|
||||
python build_tools/stable_diffusion_testing.py --gen
|
||||
python tank/generate_sharktank.py
|
||||
|
||||
@@ -63,14 +63,7 @@ def get_inpaint_inputs():
|
||||
open("./test_images/inputs/mask.png", "wb").write(mask.content)
|
||||
|
||||
|
||||
def test_loop(
|
||||
device="vulkan",
|
||||
beta=False,
|
||||
extra_flags=[],
|
||||
upload_bool=True,
|
||||
exit_on_fail=True,
|
||||
do_gen=False,
|
||||
):
|
||||
def test_loop(device="vulkan", beta=False, extra_flags=[]):
|
||||
# Get golden values from tank
|
||||
shutil.rmtree("./test_images", ignore_errors=True)
|
||||
model_metrics = []
|
||||
@@ -88,8 +81,6 @@ def test_loop(
|
||||
if beta:
|
||||
extra_flags.append("--beta_models=True")
|
||||
extra_flags.append("--no-progress_bar")
|
||||
if do_gen:
|
||||
extra_flags.append("--import_debug")
|
||||
to_skip = [
|
||||
"Linaqruf/anything-v3.0",
|
||||
"prompthero/openjourney",
|
||||
@@ -190,14 +181,7 @@ def test_loop(
|
||||
"./test_images/golden/" + model_name + "/*.png"
|
||||
)
|
||||
golden_file = glob(golden_path)[0]
|
||||
try:
|
||||
compare_images(
|
||||
test_file, golden_file, upload=upload_bool
|
||||
)
|
||||
except AssertionError as e:
|
||||
print(e)
|
||||
if exit_on_fail == True:
|
||||
raise
|
||||
compare_images(test_file, golden_file)
|
||||
else:
|
||||
print(command)
|
||||
print("failed to generate image for this configuration")
|
||||
@@ -216,9 +200,6 @@ def test_loop(
|
||||
extra_flags.remove(
|
||||
"--iree_vulkan_target_triple=rdna2-unknown-windows"
|
||||
)
|
||||
if do_gen:
|
||||
prepare_artifacts()
|
||||
|
||||
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
|
||||
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
|
||||
f.write(header)
|
||||
@@ -237,49 +218,15 @@ def test_loop(
|
||||
f.write(";".join(output) + "\n")
|
||||
|
||||
|
||||
def prepare_artifacts():
|
||||
gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
|
||||
if not os.path.isdir(gen_path):
|
||||
os.mkdir(gen_path)
|
||||
for dirname in os.listdir(os.getcwd()):
|
||||
for modelname in ["clip", "unet", "vae"]:
|
||||
if modelname in dirname and "vmfb" not in dirname:
|
||||
if not os.path.isdir(os.path.join(gen_path, dirname)):
|
||||
shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
|
||||
print(f"Moved dir: {dirname} to {gen_path}.")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("-d", "--device", default="vulkan")
|
||||
parser.add_argument(
|
||||
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
|
||||
)
|
||||
parser.add_argument("-e", "--extra_args", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"-u", "--upload", action=argparse.BooleanOptionalAction, default=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
extra_args = []
|
||||
if args.extra_args:
|
||||
for arg in args.extra_args.split(","):
|
||||
extra_args.append(arg)
|
||||
test_loop(
|
||||
args.device,
|
||||
args.beta,
|
||||
extra_args,
|
||||
args.upload,
|
||||
args.exit_on_fail,
|
||||
args.gen,
|
||||
)
|
||||
if args.gen:
|
||||
prepare_artifacts()
|
||||
test_loop(args.device, args.beta, [])
|
||||
|
||||
@@ -27,7 +27,7 @@ include(FetchContent)
|
||||
|
||||
FetchContent_Declare(
|
||||
iree
|
||||
GIT_REPOSITORY https://github.com/nod-ai/srt.git
|
||||
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
|
||||
GIT_TAG shark
|
||||
GIT_SUBMODULES_RECURSE OFF
|
||||
GIT_SHALLOW OFF
|
||||
|
||||
@@ -63,8 +63,8 @@ Where `${NUM}` is the dispatch number that you want to benchmark/profile in isol
|
||||
|
||||
### Enabling Tracy for Vulkan profiling
|
||||
|
||||
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime (SRT) builds an
|
||||
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SRT/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
|
||||
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime builds an
|
||||
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SHARK-Runtime/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
|
||||
```
|
||||
$env:IREE_ENABLE_RUNTIME_TRACING="ON"
|
||||
```
|
||||
|
||||
@@ -95,7 +95,7 @@ target_include_directories(
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
|
||||
|
||||
add_subdirectory(thirdparty/srt EXCLUDE_FROM_ALL)
|
||||
add_subdirectory(thirdparty/shark-runtime EXCLUDE_FROM_ALL)
|
||||
|
||||
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
|
||||
iree_hal_hal
|
||||
|
||||
@@ -22,7 +22,7 @@ git submodule update --init
|
||||
update the submodules of iree
|
||||
|
||||
```
|
||||
cd thirdparty/srt
|
||||
cd thirdparty/shark-runtime
|
||||
git submodule update --init
|
||||
```
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ requires = [
|
||||
"packaging",
|
||||
|
||||
"numpy>=1.22.4",
|
||||
"torch-mlir>=20230620.875",
|
||||
"torch-mlir>=20221021.633",
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
]
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
numpy>1.22.4
|
||||
pytorch-triton
|
||||
torchvision
|
||||
torchvision==0.16.0.dev20230322
|
||||
tabulate
|
||||
|
||||
tqdm
|
||||
@@ -15,7 +15,7 @@ iree-tools-tf
|
||||
|
||||
# TensorFlow and JAX.
|
||||
gin-config
|
||||
tf-nightly
|
||||
tensorflow>2.11
|
||||
keras
|
||||
#tf-models-nightly
|
||||
#tensorflow-text-nightly
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
|
||||
--pre
|
||||
|
||||
setuptools
|
||||
wheel
|
||||
|
||||
@@ -27,8 +24,7 @@ ftfy
|
||||
gradio
|
||||
altair
|
||||
omegaconf
|
||||
# 0.3.2 doesn't have binaries for arm64
|
||||
safetensors==0.3.1
|
||||
safetensors
|
||||
opencv-python
|
||||
scikit-image
|
||||
pytorch_lightning # for runwayml models
|
||||
@@ -39,7 +35,6 @@ py-cpuinfo
|
||||
tiktoken # for codegen
|
||||
joblib # for langchain
|
||||
timm # for MiniGPT4
|
||||
langchain
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
||||
@@ -90,8 +90,8 @@ python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
|
||||
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
|
||||
Write-Host "Build and installation completed successfully"
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
|
||||
@@ -103,7 +103,7 @@ else
|
||||
fi
|
||||
if [[ -z "${USE_IREE}" ]]; then
|
||||
rm .use-iree
|
||||
RUNTIME="https://nod-ai.github.io/SRT/pip-release-links.html"
|
||||
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
|
||||
else
|
||||
touch ./.use-iree
|
||||
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
|
||||
@@ -128,7 +128,7 @@ if [[ ! -z "${IMPORTER}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/cpu/
|
||||
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
|
||||
|
||||
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
T_VER=$($PYTHON -m pip show torch | grep Version)
|
||||
@@ -145,8 +145,14 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -z "${NO_BREVITAS}" ]]; then
|
||||
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@llm
|
||||
if [[ ! -z "${ONNX}" ]]; then
|
||||
echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
|
||||
$PYTHON -m pip install onnx onnxruntime psutil
|
||||
if [ $? -eq 0 ];then
|
||||
echo "Successfully installed ONNX and ONNX runtime."
|
||||
else
|
||||
echo "Could not install ONNX." >&2
|
||||
fi
|
||||
fi
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
|
||||
|
||||
@@ -43,7 +43,9 @@ if __name__ == "__main__":
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module.compile()
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
|
||||
325
shark/examples/shark_training/simple_dlrm_training.py
Normal file
325
shark/examples/shark_training/simple_dlrm_training.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import torch
|
||||
from torch.nn.utils import stateless
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
from shark.shark_trainer import SharkTrainer
|
||||
import argparse
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import SharkImporter
|
||||
import torch_mlir
|
||||
import extend_distributed as ext_dist
|
||||
|
||||
### define dlrm in PyTorch ###
|
||||
class DLRM_Net(nn.Module):
|
||||
def create_mlp(self, ln, sigmoid_layer):
|
||||
# build MLP layer by layer
|
||||
layers = nn.ModuleList()
|
||||
for i in range(0, ln.size - 1):
|
||||
n = ln[i]
|
||||
m = ln[i + 1]
|
||||
|
||||
# construct fully connected operator
|
||||
LL = nn.Linear(int(n), int(m), bias=True)
|
||||
|
||||
# initialize the weights
|
||||
# with torch.no_grad():
|
||||
# custom Xavier input, output or two-sided fill
|
||||
|
||||
mean = 0.0 # std_dev = np.sqrt(variance)
|
||||
std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n)
|
||||
W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
|
||||
std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1))
|
||||
bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
|
||||
LL.weight.data = torch.tensor(W, requires_grad=True)
|
||||
LL.bias.data = torch.tensor(bt, requires_grad=True)
|
||||
|
||||
# approach 2
|
||||
# LL.weight.data.copy_(torch.tensor(W))
|
||||
# LL.bias.data.copy_(torch.tensor(bt))
|
||||
# approach 3
|
||||
# LL.weight = Parameter(torch.tensor(W),requires_grad=True)
|
||||
# LL.bias = Parameter(torch.tensor(bt),requires_grad=True)
|
||||
layers.append(LL)
|
||||
|
||||
# construct sigmoid or relu operator
|
||||
if i == sigmoid_layer:
|
||||
layers.append(nn.Sigmoid())
|
||||
else:
|
||||
layers.append(nn.ReLU())
|
||||
|
||||
# approach 1: use ModuleList
|
||||
# return layers
|
||||
# approach 2: use Sequential container to wrap all layers
|
||||
return torch.nn.Sequential(*layers)
|
||||
|
||||
def create_emb(self, m, ln, weighted_pooling=None):
|
||||
emb_l = nn.ModuleList()
|
||||
v_W_l = []
|
||||
for i in range(0, ln.size):
|
||||
n = ln[i]
|
||||
|
||||
# construct embedding operator
|
||||
EE = nn.EmbeddingBag(n, m, mode="sum")
|
||||
# initialize embeddings
|
||||
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
|
||||
W = np.random.uniform(
|
||||
low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m)
|
||||
).astype(np.float32)
|
||||
# approach 1
|
||||
print(W)
|
||||
EE.weight.data = torch.tensor(W, requires_grad=True)
|
||||
# approach 2
|
||||
# EE.weight.data.copy_(torch.tensor(W))
|
||||
# approach 3
|
||||
# EE.weight = Parameter(torch.tensor(W),requires_grad=True)
|
||||
if weighted_pooling is None:
|
||||
v_W_l.append(None)
|
||||
else:
|
||||
v_W_l.append(torch.ones(n, dtype=torch.float32))
|
||||
emb_l.append(EE)
|
||||
return emb_l, v_W_l
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
m_spa=None,
|
||||
ln_emb=None,
|
||||
ln_bot=None,
|
||||
ln_top=None,
|
||||
arch_interaction_op=None,
|
||||
arch_interaction_itself=False,
|
||||
sigmoid_bot=-1,
|
||||
sigmoid_top=-1,
|
||||
weighted_pooling=None,
|
||||
):
|
||||
super(DLRM_Net, self).__init__()
|
||||
|
||||
if (
|
||||
(m_spa is not None)
|
||||
and (ln_emb is not None)
|
||||
and (ln_bot is not None)
|
||||
and (ln_top is not None)
|
||||
and (arch_interaction_op is not None)
|
||||
):
|
||||
# save arguments
|
||||
self.output_d = 0
|
||||
self.arch_interaction_op = arch_interaction_op
|
||||
self.arch_interaction_itself = arch_interaction_itself
|
||||
if weighted_pooling is not None and weighted_pooling != "fixed":
|
||||
self.weighted_pooling = "learned"
|
||||
else:
|
||||
self.weighted_pooling = weighted_pooling
|
||||
|
||||
# create operators
|
||||
self.emb_l, w_list = self.create_emb(
|
||||
m_spa, ln_emb, weighted_pooling
|
||||
)
|
||||
if self.weighted_pooling == "learned":
|
||||
self.v_W_l = nn.ParameterList()
|
||||
for w in w_list:
|
||||
self.v_W_l.append(nn.Parameter(w))
|
||||
else:
|
||||
self.v_W_l = w_list
|
||||
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
|
||||
self.top_l = self.create_mlp(ln_top, sigmoid_top)
|
||||
|
||||
def apply_mlp(self, x, layers):
|
||||
return layers(x)
|
||||
|
||||
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
|
||||
# WARNING: notice that we are processing the batch at once. We implicitly
|
||||
# assume that the data is laid out such that:
|
||||
# 1. each embedding is indexed with a group of sparse indices,
|
||||
# corresponding to a single lookup
|
||||
# 2. for each embedding the lookups are further organized into a batch
|
||||
# 3. for a list of embedding tables there is a list of batched lookups
|
||||
# TORCH-MLIR
|
||||
# We are passing all the embeddings as arguments for easy parsing.
|
||||
|
||||
ly = []
|
||||
for k, sparse_index_group_batch in enumerate(lS_i):
|
||||
sparse_offset_group_batch = lS_o[k]
|
||||
|
||||
# embedding lookup
|
||||
# We are using EmbeddingBag, which implicitly uses sum operator.
|
||||
# The embeddings are represented as tall matrices, with sum
|
||||
# happening vertically across 0 axis, resulting in a row vector
|
||||
# E = emb_l[k]
|
||||
|
||||
if v_W_l[k] is not None:
|
||||
per_sample_weights = v_W_l[k].gather(
|
||||
0, sparse_index_group_batch
|
||||
)
|
||||
else:
|
||||
per_sample_weights = None
|
||||
|
||||
E = emb_l[k]
|
||||
V = E(
|
||||
sparse_index_group_batch,
|
||||
sparse_offset_group_batch,
|
||||
per_sample_weights=per_sample_weights,
|
||||
)
|
||||
|
||||
ly.append(V)
|
||||
|
||||
return ly
|
||||
|
||||
def interact_features(self, x, ly):
|
||||
if self.arch_interaction_op == "dot":
|
||||
# concatenate dense and sparse features
|
||||
(batch_size, d) = x.shape
|
||||
T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
|
||||
# perform a dot product
|
||||
Z = torch.bmm(T, torch.transpose(T, 1, 2))
|
||||
# append dense feature with the interactions (into a row vector)
|
||||
# approach 1: all
|
||||
# Zflat = Z.view((batch_size, -1))
|
||||
# approach 2: unique
|
||||
_, ni, nj = Z.shape
|
||||
# approach 1: tril_indices
|
||||
# offset = 0 if self.arch_interaction_itself else -1
|
||||
# li, lj = torch.tril_indices(ni, nj, offset=offset)
|
||||
# approach 2: custom
|
||||
offset = 1 if self.arch_interaction_itself else 0
|
||||
li = torch.tensor(
|
||||
[i for i in range(ni) for j in range(i + offset)]
|
||||
)
|
||||
lj = torch.tensor(
|
||||
[j for i in range(nj) for j in range(i + offset)]
|
||||
)
|
||||
Zflat = Z[:, li, lj]
|
||||
# concatenate dense features and interactions
|
||||
R = torch.cat([x] + [Zflat], dim=1)
|
||||
elif self.arch_interaction_op == "cat":
|
||||
# concatenation features (into a row vector)
|
||||
R = torch.cat([x] + ly, dim=1)
|
||||
else:
|
||||
sys.exit(
|
||||
"ERROR: --arch-interaction-op="
|
||||
+ self.arch_interaction_op
|
||||
+ " is not supported"
|
||||
)
|
||||
|
||||
return R
|
||||
|
||||
def forward(self, dense_x, lS_o, *lS_i):
|
||||
return self.sequential_forward(dense_x, lS_o, lS_i)
|
||||
|
||||
def sequential_forward(self, dense_x, lS_o, lS_i):
|
||||
# process dense features (using bottom mlp), resulting in a row vector
|
||||
x = self.apply_mlp(dense_x, self.bot_l)
|
||||
# debug prints
|
||||
# print("intermediate")
|
||||
# print(x.detach().cpu().numpy())
|
||||
|
||||
# process sparse features(using embeddings), resulting in a list of row vectors
|
||||
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
|
||||
# for y in ly:
|
||||
# print(y.detach().cpu().numpy())
|
||||
|
||||
# interact features (dense and sparse)
|
||||
z = self.interact_features(x, ly)
|
||||
# print(z.detach().cpu().numpy())
|
||||
|
||||
# obtain probability of a click (using top mlp)
|
||||
p = self.apply_mlp(z, self.top_l)
|
||||
|
||||
# # clamp output if needed
|
||||
# if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
|
||||
# z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
|
||||
# else:
|
||||
# z = p
|
||||
|
||||
return p
|
||||
|
||||
def dash_separated_ints(value):
|
||||
vals = value.split("-")
|
||||
for val in vals:
|
||||
try:
|
||||
int(val)
|
||||
except ValueError:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"%s is not a valid dash separated list of ints" % value
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# model related parameters
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train Deep Learning Recommendation Model (DLRM)"
|
||||
)
|
||||
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
|
||||
parser.add_argument(
|
||||
"--arch-embedding-size", type=dash_separated_ints, default="4-3-2"
|
||||
)
|
||||
# j will be replaced with the table number
|
||||
parser.add_argument(
|
||||
"--arch-mlp-bot", type=dash_separated_ints, default="4-3-2"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-mlp-top", type=dash_separated_ints, default="8-2-1"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch-interaction-itself", action="store_true", default=False
|
||||
)
|
||||
parser.add_argument("--weighted-pooling", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
|
||||
ln_top = np.fromstring(args.arch_mlp_top, dtype=int, sep="-")
|
||||
m_den = ln_bot[0]
|
||||
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
|
||||
m_spa = args.arch_sparse_feature_size
|
||||
ln_emb = np.asarray(ln_emb)
|
||||
num_fea = ln_emb.size + 1 # num sparse + num dense features
|
||||
|
||||
|
||||
# Initialize the model.
|
||||
dlrm_model = DLRM_Net(
|
||||
m_spa=m_spa,
|
||||
ln_emb=ln_emb,
|
||||
ln_bot=ln_bot,
|
||||
ln_top=ln_top,
|
||||
arch_interaction_op=args.arch_interaction_op,
|
||||
)
|
||||
|
||||
def get_sorted_params(named_params):
|
||||
return [i[1] for i in sorted(named_params.items())]
|
||||
|
||||
dense_inp = torch.tensor([[0.6965, 0.2861, 0.2269, 0.5513]])
|
||||
vs0 = torch.tensor([[0], [0], [0]], dtype=torch.int64)
|
||||
vsi = torch.tensor([1, 2, 3]), torch.tensor([1]), torch.tensor([1])
|
||||
|
||||
input_dlrm = (dense_inp, vs0, *vsi)
|
||||
|
||||
mlir_importer = SharkImporter(
|
||||
dlrm_model,
|
||||
input_dlrm,
|
||||
frontend="torch",
|
||||
)
|
||||
|
||||
dlrm_mlir = torch_mlir.compile(dlrm_model, input_dlrm, torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=True)
|
||||
print(dlrm_mlir)
|
||||
|
||||
def forward(params, buffers, args):
|
||||
params_and_buffers = {**params, **buffers}
|
||||
stateless.functional_call(
|
||||
dlrm_model, params_and_buffers, args, {}
|
||||
).sum().backward()
|
||||
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
|
||||
# optim.load_state_dict(optim_state)
|
||||
optim.step()
|
||||
return params, buffers
|
||||
|
||||
shark_module = SharkTrainer(dlrm_model, input_dlrm)
|
||||
print("________________________________________________________________________")
|
||||
shark_module.compile(forward)
|
||||
print("________________________________________________________________________")
|
||||
shark_module.train(num_iters=2)
|
||||
print("training done")
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
## Common utilities to be shared by iree utilities.
|
||||
import functools
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
@@ -93,7 +93,6 @@ _IREE_TARGET_MAP = {
|
||||
|
||||
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
@functools.cache
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
if "://" in device:
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import iree._runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
|
||||
from shark.iree_utils._common import run_cmd, iree_device_map
|
||||
from shark.iree_utils.cpu_utils import get_cpu_count
|
||||
import numpy as np
|
||||
@@ -62,12 +62,16 @@ def build_benchmark_args(
|
||||
and whether it is training or not.
|
||||
Outputs: string that execute benchmark-module on target model.
|
||||
"""
|
||||
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
|
||||
path = benchmark_module.__path__[0]
|
||||
if platform.system() == "Windows":
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module.exe"
|
||||
)
|
||||
time_extractor = None
|
||||
else:
|
||||
benchmarker_path = os.path.join(path, "iree-benchmark-module")
|
||||
benchmarker_path = os.path.join(
|
||||
path, "..", "..", "iree-benchmark-module"
|
||||
)
|
||||
time_extractor = "| awk 'END{{print $2 $3}}'"
|
||||
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
|
||||
# TODO: The function named can be passed as one of the args.
|
||||
|
||||
@@ -11,23 +11,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import functools
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import iree_device_map, iree_target_map
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
|
||||
from shark.iree_utils.benchmark_utils import *
|
||||
from shark.parser import shark_args
|
||||
import numpy as np
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.parser import shark_args
|
||||
|
||||
from .trace import DetailLogger
|
||||
from ._common import iree_device_map, iree_target_map
|
||||
from .cpu_utils import get_iree_cpu_rt_args
|
||||
from .benchmark_utils import *
|
||||
|
||||
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
@@ -95,7 +90,6 @@ def get_iree_frontend_args(frontend):
|
||||
def get_iree_common_args():
|
||||
return [
|
||||
"--iree-stream-resource-index-bits=64",
|
||||
"--iree-stream-resource-max-allocation-size=4294967295",
|
||||
"--iree-vm-target-index-bits=64",
|
||||
"--iree-vm-bytecode-module-strip-source-map=true",
|
||||
"--iree-util-zero-fill-elided-attrs",
|
||||
@@ -323,6 +317,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
haldriver = ireert.get_driver(device)
|
||||
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
allocators=shark_args.device_allocator,
|
||||
@@ -342,64 +337,58 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
|
||||
def load_vmfb_using_mmap(
|
||||
flatbuffer_blob_or_path, device: str, device_idx: int = None
|
||||
):
|
||||
print(f"Loading module {flatbuffer_blob_or_path}...")
|
||||
instance = ireert.VmInstance()
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
allocators=[],
|
||||
)
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
haldriver = ireert.get_driver(device)
|
||||
|
||||
with DetailLogger(timeout=2.5) as dl:
|
||||
# First get configs.
|
||||
if device_idx is not None:
|
||||
dl.log(f"Mapping device id: {device_idx}")
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
dl.log(f"ireert.get_driver()")
|
||||
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
dl.log(f"ireert.create_device()")
|
||||
config = ireert.Config(device=haldevice)
|
||||
dl.log(f"ireert.Config()")
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
dl.log("get_iree_runtime_config")
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
# (This would arise if we're invoking `load_module` from a SharkInference obj)
|
||||
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
|
||||
# (This would arise if we're invoking `compile` from a SharkInference obj)
|
||||
temp_file_to_unlink = None
|
||||
if isinstance(flatbuffer_blob_or_path, Path):
|
||||
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
|
||||
if (
|
||||
isinstance(flatbuffer_blob_or_path, str)
|
||||
and ".vmfb" in flatbuffer_blob_or_path
|
||||
):
|
||||
vmfb_file_path = flatbuffer_blob_or_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(
|
||||
config.vm_instance, flatbuffer_blob_or_path
|
||||
)
|
||||
dl.log(f"mmap {flatbuffer_blob_or_path}")
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
dl.log(f"ireert.SystemContext created")
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
dl.log(f"module initialized")
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
tf.write(flatbuffer_blob_or_path)
|
||||
tf.flush()
|
||||
vmfb_file_path = tf.name
|
||||
temp_file_to_unlink = vmfb_file_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
|
||||
dl.log(f"mmap temp {vmfb_file_path}")
|
||||
return mmaped_vmfb, config, temp_file_to_unlink
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
# (This would arise if we're invoking `load_module` from a SharkInference obj)
|
||||
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
|
||||
# (This would arise if we're invoking `compile` from a SharkInference obj)
|
||||
temp_file_to_unlink = None
|
||||
if isinstance(flatbuffer_blob_or_path, Path):
|
||||
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
|
||||
if (
|
||||
isinstance(flatbuffer_blob_or_path, str)
|
||||
and ".vmfb" in flatbuffer_blob_or_path
|
||||
):
|
||||
vmfb_file_path = flatbuffer_blob_or_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
|
||||
ctx = ireert.SystemContext(config=config)
|
||||
ctx.add_vm_module(mmaped_vmfb)
|
||||
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
|
||||
else:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tf:
|
||||
tf.write(flatbuffer_blob_or_path)
|
||||
tf.flush()
|
||||
vmfb_file_path = tf.name
|
||||
temp_file_to_unlink = vmfb_file_path
|
||||
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
|
||||
return mmaped_vmfb, config, temp_file_to_unlink
|
||||
|
||||
|
||||
def get_iree_compiled_module(
|
||||
@@ -421,6 +410,7 @@ def get_iree_compiled_module(
|
||||
# we're setting delete=False when creating NamedTemporaryFile. That's why
|
||||
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
|
||||
if mmap:
|
||||
print(f"Will load the compiled module as a mmapped temporary file")
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_blob, device, device_idx
|
||||
)
|
||||
@@ -444,6 +434,7 @@ def load_flatbuffer(
|
||||
):
|
||||
temp_file_to_unlink = None
|
||||
if mmap:
|
||||
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
|
||||
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
|
||||
flatbuffer_path, device, device_idx
|
||||
)
|
||||
@@ -507,56 +498,37 @@ def get_results(
|
||||
config,
|
||||
frontend="torch",
|
||||
send_to_host=True,
|
||||
debug_timeout: float = 5.0,
|
||||
):
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
with DetailLogger(debug_timeout) as dl:
|
||||
device_inputs = []
|
||||
for input_array in input:
|
||||
dl.log(f"Load to device: {input_array.shape}")
|
||||
device_inputs.append(
|
||||
ireert.asdevicearray(config.device, input_array)
|
||||
)
|
||||
dl.log(f"Invoke function: {function_name}")
|
||||
result = compiled_vm[function_name](*device_inputs)
|
||||
dl.log(f"Invoke complete")
|
||||
result_tensors = []
|
||||
if isinstance(result, tuple):
|
||||
if send_to_host:
|
||||
for val in result:
|
||||
dl.log(f"Result to host: {val.shape}")
|
||||
result_tensors.append(np.asarray(val, val.dtype))
|
||||
else:
|
||||
for val in result:
|
||||
result_tensors.append(val)
|
||||
return result_tensors
|
||||
elif isinstance(result, dict):
|
||||
data = list(result.items())
|
||||
if send_to_host:
|
||||
res = np.array(data, dtype=object)
|
||||
return np.copy(res)
|
||||
return data
|
||||
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
|
||||
result = compiled_vm[function_name](*device_inputs)
|
||||
result_tensors = []
|
||||
if isinstance(result, tuple):
|
||||
if send_to_host:
|
||||
for val in result:
|
||||
result_tensors.append(np.asarray(val, val.dtype))
|
||||
else:
|
||||
if send_to_host and result is not None:
|
||||
dl.log("Result to host")
|
||||
return result.to_host()
|
||||
return result
|
||||
dl.log("Execution complete")
|
||||
for val in result:
|
||||
result_tensors.append(val)
|
||||
return result_tensors
|
||||
elif isinstance(result, dict):
|
||||
data = list(result.items())
|
||||
if send_to_host:
|
||||
res = np.array(data, dtype=object)
|
||||
return np.copy(res)
|
||||
return data
|
||||
else:
|
||||
if send_to_host and result is not None:
|
||||
return result.to_host()
|
||||
return result
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_iree_runtime_config(device):
|
||||
device = iree_device_map(device)
|
||||
haldriver = ireert.get_driver(device)
|
||||
if device == "metal" and shark_args.device_allocator == "caching":
|
||||
print(
|
||||
"[WARNING] metal devices can not have a `caching` allocator."
|
||||
"\nUsing default allocator `None`"
|
||||
)
|
||||
haldevice = haldriver.create_device_by_uri(
|
||||
device,
|
||||
# metal devices have a failure with caching allocators atm. blcking this util it gets fixed upstream.
|
||||
allocators=shark_args.device_allocator if device != "metal" else None,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
return config
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
# All the iree_cpu related functionalities go here.
|
||||
|
||||
import functools
|
||||
import subprocess
|
||||
import platform
|
||||
from shark.parser import shark_args
|
||||
@@ -31,7 +30,6 @@ def get_cpu_count():
|
||||
|
||||
|
||||
# Get the default cpu args.
|
||||
@functools.cache
|
||||
def get_iree_cpu_args():
|
||||
uname = platform.uname()
|
||||
os_name, proc_name = uname.system, uname.machine
|
||||
@@ -53,7 +51,6 @@ def get_iree_cpu_args():
|
||||
|
||||
|
||||
# Get iree runtime flags for cpu
|
||||
@functools.cache
|
||||
def get_iree_cpu_rt_args():
|
||||
default = get_cpu_count()
|
||||
default = default if default <= 8 else default - 2
|
||||
|
||||
@@ -14,14 +14,12 @@
|
||||
|
||||
# All the iree_gpu related functionalities go here.
|
||||
|
||||
import functools
|
||||
import iree.runtime as ireert
|
||||
import ctypes
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@functools.cache
|
||||
def get_iree_gpu_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
ireert.flags.parse_flags("--cuda_allow_inline_execution")
|
||||
@@ -39,7 +37,6 @@ def get_iree_gpu_args():
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@functools.cache
|
||||
def get_iree_rocm_args():
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
# get arch from rocminfo.
|
||||
@@ -68,7 +65,6 @@ CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
|
||||
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_cuda_sm_cc():
|
||||
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
|
||||
for libname in libnames:
|
||||
|
||||
@@ -14,15 +14,12 @@
|
||||
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
import functools
|
||||
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_metal_device_name(device_num=0):
|
||||
iree_device_dump = run_cmd("iree-run-module --dump_devices")
|
||||
iree_device_dump = iree_device_dump[0].split("\n\n")
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
# Copyright 2023 The Nod Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
def _enable_detail_trace() -> bool:
|
||||
return os.getenv("SHARK_DETAIL_TRACE", "0") == "1"
|
||||
|
||||
|
||||
class DetailLogger:
|
||||
"""Context manager which can accumulate detailed log messages.
|
||||
|
||||
Detailed log is only emitted if the operation takes a long time
|
||||
or errors.
|
||||
"""
|
||||
|
||||
def __init__(self, timeout: float):
|
||||
self._timeout = timeout
|
||||
self._messages: List[Tuple[float, str]] = []
|
||||
self._start_time = time.time()
|
||||
self._active = not _enable_detail_trace()
|
||||
self._lock = threading.RLock()
|
||||
self._cond = threading.Condition(self._lock)
|
||||
self._thread = None
|
||||
|
||||
def __enter__(self):
|
||||
self._thread = threading.Thread(target=self._run)
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
with self._lock:
|
||||
self._active = False
|
||||
self._cond.notify()
|
||||
if traceback:
|
||||
self.dump_on_error(f"exception")
|
||||
|
||||
def _run(self):
|
||||
with self._lock:
|
||||
timed_out = not self._cond.wait(self._timeout)
|
||||
if timed_out:
|
||||
self.dump_on_error(f"took longer than {self._timeout}s")
|
||||
|
||||
def log(self, msg):
|
||||
with self._lock:
|
||||
timestamp = time.time()
|
||||
if self._active:
|
||||
self._messages.append((timestamp, msg))
|
||||
else:
|
||||
print(f" +{(timestamp - self._start_time) * 1000}ms: {msg}")
|
||||
|
||||
def dump_on_error(self, summary: str):
|
||||
with self._lock:
|
||||
if self._active:
|
||||
print(f"::: Detailed report ({summary}):")
|
||||
for timestamp, msg in self._messages:
|
||||
print(
|
||||
f" +{(timestamp - self._start_time) * 1000}ms: {msg}"
|
||||
)
|
||||
self._active = False
|
||||
@@ -13,10 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict
|
||||
import functools
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_target_env(vulkan_target_triple):
|
||||
arch, product, os = vulkan_target_triple.split("=")[1].split("-")
|
||||
triple = (arch, product, os)
|
||||
@@ -54,7 +52,6 @@ def get_version(triple):
|
||||
return "v1.3"
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_extensions(triple):
|
||||
def make_ext_list(ext_list):
|
||||
res = ""
|
||||
@@ -125,7 +122,6 @@ def get_extensions(triple):
|
||||
return make_ext_list(ext_list=ext)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vendor(triple):
|
||||
arch, product, os = triple
|
||||
if arch == "unknown":
|
||||
@@ -150,7 +146,6 @@ def get_vendor(triple):
|
||||
return "Unknown"
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_device_type(triple):
|
||||
arch, product, _ = triple
|
||||
if arch == "unknown":
|
||||
@@ -171,7 +166,6 @@ def get_device_type(triple):
|
||||
|
||||
# get all the capabilities for the device
|
||||
# TODO: make a dataclass for capabilites and init using vulkaninfo
|
||||
@functools.cache
|
||||
def get_vulkan_target_capabilities(triple):
|
||||
def get_subgroup_val(l):
|
||||
return int(sum([subgroup_feature[sgf] for sgf in l]))
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
# All the iree_vulkan related functionalities go here.
|
||||
|
||||
import functools
|
||||
from os import linesep
|
||||
from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
@@ -23,7 +22,6 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
|
||||
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
|
||||
@@ -50,7 +48,6 @@ def get_os_name():
|
||||
return "linux"
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_target_triple(device_name):
|
||||
"""This method provides a target triple str for specified vulkan device.
|
||||
|
||||
@@ -175,7 +172,6 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
return res_vulkan_flag
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_iree_vulkan_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={shark_args.vulkan_large_heap_block_size}",
|
||||
|
||||
@@ -114,7 +114,7 @@ parser.add_argument(
|
||||
"--device_allocator",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["caching"],
|
||||
default=[],
|
||||
help="Specifies one or more HAL device allocator specs "
|
||||
"to augment the base device allocator",
|
||||
choices=["debug", "caching"],
|
||||
@@ -149,7 +149,7 @@ parser.add_argument(
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_vma_allocator",
|
||||
default=False,
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling / disabling Vulkan VMA Allocator.",
|
||||
)
|
||||
|
||||
@@ -13,11 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from shark.shark_runner import SharkRunner
|
||||
from shark.iree_utils.compile_utils import (
|
||||
export_iree_module_to_vmfb,
|
||||
load_flatbuffer,
|
||||
get_iree_runtime_config,
|
||||
)
|
||||
from shark.iree_utils.compile_utils import export_iree_module_to_vmfb
|
||||
from shark.iree_utils.benchmark_utils import (
|
||||
build_benchmark_args,
|
||||
run_benchmark_module,
|
||||
@@ -83,31 +79,22 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
self.mlir_dialect = mlir_dialect
|
||||
self.extra_args = extra_args
|
||||
self.import_args = {}
|
||||
self.temp_file_to_unlink = None
|
||||
SharkRunner.__init__(
|
||||
self,
|
||||
mlir_module,
|
||||
device,
|
||||
self.mlir_dialect,
|
||||
self.extra_args,
|
||||
compile_vmfb=False,
|
||||
compile_vmfb=True,
|
||||
)
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module,
|
||||
device,
|
||||
".",
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
)
|
||||
params = load_flatbuffer(
|
||||
self.vmfb_file,
|
||||
device,
|
||||
mmap=True,
|
||||
)
|
||||
self.iree_compilation_module = params["vmfb"]
|
||||
self.iree_config = params["config"]
|
||||
self.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
del params
|
||||
if self.vmfb_file == None:
|
||||
self.vmfb_file = export_iree_module_to_vmfb(
|
||||
mlir_module,
|
||||
device,
|
||||
".",
|
||||
self.mlir_dialect,
|
||||
extra_args=self.extra_args,
|
||||
)
|
||||
|
||||
def setup_cl(self, input_tensors):
|
||||
self.benchmark_cl = build_benchmark_args(
|
||||
@@ -124,41 +111,42 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
elif self.mlir_dialect in ["mhlo", "tf"]:
|
||||
return self.benchmark_tf(modelname)
|
||||
|
||||
def benchmark_torch(self, modelname, device="cpu"):
|
||||
def benchmark_torch(self, modelname):
|
||||
import torch
|
||||
from tank.model_utils import get_torch_model
|
||||
|
||||
# TODO: Pass this as an arg. currently the best way is to setup with BENCHMARK=1 if we want to use torch+cuda, else use cpu.
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if device == "cuda":
|
||||
torch.set_default_device("cuda:0")
|
||||
# if self.enable_tf32:
|
||||
# torch.backends.cuda.matmul.allow_tf32 = True
|
||||
if self.device == "cuda":
|
||||
torch.set_default_tensor_type(torch.cuda.FloatTensor)
|
||||
if self.enable_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
else:
|
||||
torch.set_default_dtype(torch.float32)
|
||||
torch.set_default_device("cpu")
|
||||
torch_device = torch.device("cuda:0" if device == "cuda" else "cpu")
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
torch_device = torch.device(
|
||||
"cuda:0" if self.device == "cuda" else "cpu"
|
||||
)
|
||||
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
|
||||
frontend_model = HFmodel.model
|
||||
frontend_model.to(torch_device)
|
||||
if device == "cuda":
|
||||
frontend_model.cuda()
|
||||
input.to(torch.device("cuda:0"))
|
||||
print(input)
|
||||
else:
|
||||
frontend_model.cpu()
|
||||
input.cpu()
|
||||
input.to(torch_device)
|
||||
|
||||
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
|
||||
try:
|
||||
frontend_model = torch.compile(
|
||||
frontend_model, mode="max-autotune", backend="inductor"
|
||||
)
|
||||
except RuntimeError:
|
||||
frontend_model = HFmodel.model
|
||||
|
||||
for i in range(shark_args.num_warmup_iterations):
|
||||
frontend_model.forward(input)
|
||||
|
||||
if device == "cuda":
|
||||
if self.device == "cuda":
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
begin = time.time()
|
||||
for i in range(shark_args.num_iterations):
|
||||
out = frontend_model.forward(input)
|
||||
end = time.time()
|
||||
if device == "cuda":
|
||||
if self.device == "cuda":
|
||||
stats = torch.cuda.memory_stats()
|
||||
device_peak_b = stats["allocated_bytes.all.peak"]
|
||||
frontend_model.to(torch.device("cpu"))
|
||||
@@ -170,7 +158,7 @@ class SharkBenchmarkRunner(SharkRunner):
|
||||
print(
|
||||
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
|
||||
)
|
||||
if device == "cuda":
|
||||
if self.device == "cuda":
|
||||
# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
|
||||
torch_device = torch.device("cpu")
|
||||
return [
|
||||
|
||||
@@ -11,8 +11,14 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
# fmt: off
|
||||
def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(
|
||||
lhs: List[int],
|
||||
rhs: List[int],
|
||||
rhs_scale: List[int],
|
||||
rhs_zero_point: List[int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
@@ -21,21 +27,30 @@ def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_s
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(
|
||||
lhs_rank_dtype: Tuple[int, int],
|
||||
rhs_rank_dtype: Tuple[int, int],
|
||||
rhs_scale_rank_dtype: Tuple[int, int],
|
||||
rhs_zero_point_rank_dtype: Tuple[int, int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
|
||||
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
quant〇matmul_rhs_group_quant〡shape,
|
||||
quant〇matmul_rhs_group_quant〡dtype,
|
||||
quant〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
|
||||
]
|
||||
|
||||
|
||||
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
@@ -107,7 +122,7 @@ def compile_int_precision(
|
||||
torchscript_module,
|
||||
inputs,
|
||||
output_type="torch",
|
||||
backend_legal_ops=["quant.matmul_rhs_group_quant"],
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
|
||||
@@ -111,20 +111,22 @@ os.makedirs(WORKDIR, exist_ok=True)
|
||||
def check_dir_exists(model_name, frontend="torch", dynamic=""):
|
||||
model_dir = os.path.join(WORKDIR, model_name)
|
||||
|
||||
# Remove the _tf keyword from end only for non-SD models.
|
||||
if not any(model in model_name for model in ["clip", "unet", "vae"]):
|
||||
if frontend in ["tf", "tensorflow"]:
|
||||
model_name = model_name[:-3]
|
||||
elif frontend in ["tflite"]:
|
||||
model_name = model_name[:-7]
|
||||
elif frontend in ["torch", "pytorch"]:
|
||||
model_name = model_name[:-6]
|
||||
|
||||
model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir"
|
||||
# Remove the _tf keyword from end.
|
||||
if frontend in ["tf", "tensorflow"]:
|
||||
model_name = model_name[:-3]
|
||||
elif frontend in ["tflite"]:
|
||||
model_name = model_name[:-7]
|
||||
elif frontend in ["torch", "pytorch"]:
|
||||
model_name = model_name[:-6]
|
||||
|
||||
if os.path.isdir(model_dir):
|
||||
if (
|
||||
os.path.isfile(os.path.join(model_dir, model_mlir_file_name))
|
||||
os.path.isfile(
|
||||
os.path.join(
|
||||
model_dir,
|
||||
model_name + dynamic + "_" + str(frontend) + ".mlir",
|
||||
)
|
||||
)
|
||||
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
|
||||
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
|
||||
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import re
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
import torch_mlir
|
||||
from iree.compiler import compile_str
|
||||
from shark.shark_importer import import_with_fx, get_f16_inputs
|
||||
@@ -13,7 +11,6 @@ class GenerateConfigFile:
|
||||
model,
|
||||
num_sharding_stages: int,
|
||||
sharding_stages_id: list[str],
|
||||
units_in_each_stage: list[int],
|
||||
model_input=None,
|
||||
config_file_path="model_config.json",
|
||||
):
|
||||
@@ -25,16 +22,13 @@ class GenerateConfigFile:
|
||||
), "Number of sharding stages should be equal to the list of their ID"
|
||||
self.model_input = model_input
|
||||
self.config_file_path = config_file_path
|
||||
# (Nithin) this is a quick fix - revisit and rewrite
|
||||
self.units_in_each_stage = np.array(units_in_each_stage)
|
||||
self.track_loop = np.zeros(len(self.sharding_stages_id)).astype(int)
|
||||
|
||||
def split_into_dispatches(
|
||||
self,
|
||||
backend,
|
||||
fx_tracing_required=False,
|
||||
fx_tracing_required=True,
|
||||
f16_model=False,
|
||||
torch_mlir_tracing=True,
|
||||
torch_mlir_tracing=False,
|
||||
):
|
||||
graph_for_compilation = self.model
|
||||
if fx_tracing_required:
|
||||
@@ -101,17 +95,7 @@ class GenerateConfigFile:
|
||||
if substring_before_final_period in model_dictionary:
|
||||
del model_dictionary[substring_before_final_period]
|
||||
|
||||
# layer_dict = {n: "None" for n in self.sharding_stages_id}
|
||||
|
||||
# By default embed increasing device id's for each layer
|
||||
increasing_wraparound_idx_list = (
|
||||
self.track_loop % self.units_in_each_stage
|
||||
)
|
||||
layer_dict = {
|
||||
n: int(increasing_wraparound_idx_list[idx][0][0])
|
||||
for idx, n in enumerate(self.sharding_stages_id)
|
||||
}
|
||||
self.track_loop += 1
|
||||
layer_dict = {n: "None" for n in self.sharding_stages_id}
|
||||
model_dictionary[name] = layer_dict
|
||||
|
||||
self.generate_json(model_dictionary)
|
||||
@@ -119,29 +103,3 @@ class GenerateConfigFile:
|
||||
def generate_json(self, artifacts):
|
||||
with open(self.config_file_path, "w") as outfile:
|
||||
json.dump(artifacts, outfile)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
|
||||
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
|
||||
[1, 19]
|
||||
)
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
from apps.language_models.src.model_wrappers.vicuna_model import (
|
||||
FirstVicuna,
|
||||
SecondVicuna,
|
||||
CombinedModel,
|
||||
)
|
||||
|
||||
model = CombinedModel()
|
||||
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
|
||||
c.split_into_layers()
|
||||
|
||||
@@ -612,7 +612,7 @@ def import_with_fx(
|
||||
replace_call_fn_target(
|
||||
fx_g,
|
||||
src=matmul_rhs_group_quant_placeholder,
|
||||
target=torch.ops.quant.matmul_rhs_group_quant,
|
||||
target=torch.ops.brevitas.matmul_rhs_group_quant,
|
||||
)
|
||||
|
||||
fx_g.recompile()
|
||||
|
||||
@@ -141,10 +141,6 @@ class SharkInference:
|
||||
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run(function_name, inputs, send_to_host)
|
||||
|
||||
# forward function.
|
||||
def forward(self, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run("forward", inputs, send_to_host)
|
||||
|
||||
# Get all function names defined within the compiled module.
|
||||
def get_functions_in_module(self):
|
||||
return self.shark_runner.get_functions_in_module()
|
||||
|
||||
@@ -13,6 +13,7 @@ google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,
|
||||
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
|
||||
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
|
||||
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
|
||||
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
|
||||
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
|
||||
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
|
||||
@@ -29,7 +30,7 @@ nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github
|
||||
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
|
||||
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
|
||||
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
|
||||
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
|
||||
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
|
||||
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
|
||||
@@ -43,3 +44,4 @@ t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq m
|
||||
t5-base,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
|
||||
t5-large,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
|
||||
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"","macos"
|
||||
|
||||
|
@@ -16,6 +16,12 @@ import subprocess as sp
|
||||
import hashlib
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from apps.stable_diffusion.src.models import (
|
||||
model_wrappers as mw,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.stable_args import (
|
||||
args,
|
||||
)
|
||||
|
||||
|
||||
def create_hash(file_name):
|
||||
@@ -54,6 +60,31 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
print("generating artifacts for: " + torch_model_name)
|
||||
model = None
|
||||
input = None
|
||||
if model_type == "stable_diffusion":
|
||||
args.use_tuned = False
|
||||
args.import_mlir = True
|
||||
args.local_tank_cache = local_tank_cache
|
||||
|
||||
precision_values = ["fp16"]
|
||||
seq_lengths = [64, 77]
|
||||
for precision_value in precision_values:
|
||||
args.precision = precision_value
|
||||
for length in seq_lengths:
|
||||
model = mw.SharkifyStableDiffusionModel(
|
||||
model_id=torch_model_name,
|
||||
custom_weights="",
|
||||
precision=precision_value,
|
||||
max_len=length,
|
||||
width=512,
|
||||
height=512,
|
||||
use_base_vae=False,
|
||||
custom_vae="",
|
||||
debug=True,
|
||||
sharktank_dir=local_tank_cache,
|
||||
generate_vmfb=False,
|
||||
)
|
||||
model()
|
||||
continue
|
||||
if model_type == "vision":
|
||||
model, input, _ = get_vision_model(
|
||||
torch_model_name, import_args
|
||||
@@ -72,11 +103,10 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
|
||||
model, input, _ = get_hf_img_cls_model(
|
||||
torch_model_name, import_args
|
||||
)
|
||||
elif model_type == "fp16":
|
||||
model, input, _ = get_fp16_model(torch_model_name, import_args)
|
||||
torch_model_name = torch_model_name.replace("/", "_")
|
||||
if import_args["batch_size"] > 1:
|
||||
print(
|
||||
f"Batch size for this model set to {import_args['batch_size']}"
|
||||
)
|
||||
if import_args["batch_size"] != 1:
|
||||
torch_model_dir = os.path.join(
|
||||
local_tank_cache,
|
||||
str(torch_model_name)
|
||||
@@ -361,7 +391,7 @@ if __name__ == "__main__":
|
||||
|
||||
# old_import_args = parser.parse_import_args()
|
||||
import_args = {
|
||||
"batch_size": 1,
|
||||
"batch_size": "1",
|
||||
}
|
||||
print(import_args)
|
||||
home = str(Path.home())
|
||||
@@ -374,6 +404,11 @@ if __name__ == "__main__":
|
||||
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
|
||||
)
|
||||
|
||||
save_torch_model(
|
||||
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
|
||||
WORKDIR,
|
||||
import_args,
|
||||
)
|
||||
save_torch_model(torch_model_csv, WORKDIR, import_args)
|
||||
# save_tf_model(tf_model_csv, WORKDIR, import_args)
|
||||
# save_tflite_model(tflite_model_csv, WORKDIR, import_args)
|
||||
save_tf_model(tf_model_csv, WORKDIR, import_args)
|
||||
save_tflite_model(tflite_model_csv, WORKDIR, import_args)
|
||||
|
||||
@@ -278,7 +278,7 @@ def get_vision_model(torch_model, import_args):
|
||||
int(import_args["batch_size"]), 3, *input_image_size
|
||||
)
|
||||
actual_out = model(test_input)
|
||||
if fp16_model == True:
|
||||
if fp16_model is not None:
|
||||
test_input_fp16 = test_input.to(
|
||||
device=torch.device("cuda"), dtype=torch.half
|
||||
)
|
||||
|
||||
@@ -145,7 +145,6 @@ class SharkModuleTester:
|
||||
shark_args.shark_prefix = self.shark_tank_prefix
|
||||
shark_args.local_tank_cache = self.local_tank_cache
|
||||
shark_args.dispatch_benchmarks = self.benchmark_dispatches
|
||||
shark_args.enable_tf32 = self.tf32
|
||||
|
||||
if self.benchmark_dispatches is not None:
|
||||
_m = self.config["model_name"].split("/")
|
||||
@@ -217,12 +216,10 @@ class SharkModuleTester:
|
||||
|
||||
result = shark_module(func_name, inputs)
|
||||
golden_out, result = self.postprocess_outputs(golden_out, result)
|
||||
if self.tf32 == True:
|
||||
print(
|
||||
"Validating with relaxed tolerances for TensorFloat32 calculations."
|
||||
)
|
||||
self.config["atol"] = 1e-01
|
||||
self.config["rtol"] = 1e-02
|
||||
if self.tf32 == "true":
|
||||
print("Validating with relaxed tolerances.")
|
||||
atol = 1e-02
|
||||
rtol = 1e-03
|
||||
try:
|
||||
np.testing.assert_allclose(
|
||||
golden_out,
|
||||
@@ -257,6 +254,9 @@ class SharkModuleTester:
|
||||
model_config = {
|
||||
"batch_size": self.batch_size,
|
||||
}
|
||||
shark_args.enable_tf32 = self.tf32
|
||||
if shark_args.enable_tf32 == True:
|
||||
shark_module.compile()
|
||||
|
||||
shark_args.onnx_bench = self.onnx_bench
|
||||
shark_module.shark_runner.benchmark_all_csv(
|
||||
|
||||
@@ -5,6 +5,7 @@ microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,False,66M,"nlp;bert-varian
|
||||
bert-base-uncased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-base-cased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
google/mobilebert-uncased,True,hf,True,linalg,False,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
|
||||
alexnet,False,vision,True,linalg,False,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
|
||||
resnet18,False,vision,True,linalg,False,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
|
||||
resnet50,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
resnet101,False,vision,True,linalg,False,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
@@ -17,9 +18,11 @@ facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,False,22M
|
||||
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
|
||||
nvidia/mit-b0,True,hf_img_cls,False,linalg,False,3.7M,"image-classification,transformer-encoder",SegFormer
|
||||
mnasnet1_0,False,vision,True,linalg,False,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
|
||||
resnet50_fp16,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
|
||||
bert-base-uncased_fp16,True,fp16,False,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
|
||||
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
|
||||
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
|
||||
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
|
||||
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
|
||||
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
|
||||
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
|
||||
|
Reference in New Issue
Block a user