Compare commits

..

1 Commits

Author SHA1 Message Date
Elias Joseph
16daba99fe wip script for lowering dlrm training 2023-09-06 03:48:20 +00:00
100 changed files with 4537 additions and 7320 deletions

View File

@@ -51,11 +51,11 @@ jobs:
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
- name: Upload Release Assets
id: upload-release-assets
@@ -104,7 +104,7 @@ jobs:
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
@@ -144,7 +144,7 @@ jobs:
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models

6
.gitignore vendored
View File

@@ -193,9 +193,3 @@ stencil_annotator/
# For DocuChat
apps/language_models/langchain/user_path/
db_dir_UserData
# Embeded browser cache and other
apps/stable_diffusion/web/EBWebView/
# Llama2 tokenizer configs
llama2_tokenizer_configs/

2
.gitmodules vendored
View File

@@ -1,4 +1,4 @@
[submodule "inference/thirdparty/shark-runtime"]
path = inference/thirdparty/shark-runtime
url =https://github.com/nod-ai/SRT.git
url =https://github.com/nod-ai/SHARK-Runtime.git
branch = shark-06032022

View File

@@ -10,7 +10,7 @@ High Performance Machine Learning Distribution
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download the latest driver (23.2.1 is the oldest supported) [here](https://www.amd.com/en/support).
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-23-2-1).
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
@@ -170,7 +170,7 @@ python -m pip install --upgrade pip
This step pip installs SHARK and related packages on Linux Python 3.8, 3.10 and 3.11 and macOS / Windows Python 3.11
```shell
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
pip install nodai-shark -f https://nod-ai.github.io/SHARK/package-index/ -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```
### Run shark tank model tests.

View File

@@ -1,3 +1,4 @@
"""Load question answering chains."""
from __future__ import annotations
from typing import (
Any,
@@ -10,34 +11,23 @@ from typing import (
Union,
Protocol,
)
import inspect
import json
import warnings
from pathlib import Path
import yaml
from abc import ABC, abstractmethod
import langchain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.question_answering import stuff_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.docstore.document import Document
from abc import ABC, abstractmethod
from langchain.chains.base import Chain
from langchain.callbacks.manager import (
CallbackManager,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
from pydantic import Extra, Field, root_validator, validator
def _get_verbosity() -> bool:
return langchain.verbose
from pydantic import Extra, Field, root_validator
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
@@ -58,413 +48,6 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
return prompt.format(**document_info)
class Chain(Serializable, ABC):
"""Base interface that all chains should implement."""
memory: Optional[BaseMemory] = None
callbacks: Callbacks = Field(default=None, exclude=True)
callback_manager: Optional[BaseCallbackManager] = Field(
default=None, exclude=True
)
verbose: bool = Field(
default_factory=_get_verbosity
) # Whether to print the response text
tags: Optional[List[str]] = None
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@property
def _chain_type(self) -> str:
raise NotImplementedError("Saving not supported for this chain type.")
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used."""
if values.get("callback_manager") is not None:
warnings.warn(
"callback_manager is deprecated. Please use callbacks instead.",
DeprecationWarning,
)
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
@abstractmethod
def input_keys(self) -> List[str]:
"""Input keys this chain expects."""
@property
@abstractmethod
def output_keys(self) -> List[str]:
"""Output keys this chain expects."""
def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Check that all inputs are present."""
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
missing_keys = set(self.output_keys).difference(outputs)
if missing_keys:
raise ValueError(f"Missing some output keys: {missing_keys}")
@abstractmethod
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Run the logic of this chain and return the output."""
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
input_docs = inputs["input_documents"]
missing_keys = set(self.input_keys).difference(inputs)
if missing_keys:
raise ValueError(f"Missing some input keys: {missing_keys}")
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
)
if "is_first" in inputs.keys() and not inputs["is_first"]:
run_manager_ = run_manager
input_list = [inputs]
stop = None
prompts = []
for inputs in input_list:
selected_inputs = {
k: inputs[k] for k in self.prompt.input_variables
}
prompt = self.prompt.format_prompt(**selected_inputs)
_colored_text = get_colored_text(prompt.to_string(), "green")
_text = "Prompt after formatting:\n" + _colored_text
if run_manager_:
run_manager_.on_text(_text, end="\n", verbose=self.verbose)
if "stop" in inputs and inputs["stop"] != stop:
raise ValueError(
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
prompt_strings = [p.to_string() for p in prompts]
prompts = prompt_strings
callbacks = run_manager_.get_child() if run_manager_ else None
tags = None
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
# not make sense.
if not isinstance(prompts, list):
raise ValueError(
"Argument 'prompts' is expected to be of type List[str], received"
f" argument of type {type(prompts)}."
)
params = self.llm.dict()
params["stop"] = stop
options = {"stop": stop}
disregard_cache = self.llm.cache is not None and not self.llm.cache
callback_manager = CallbackManager.configure(
callbacks,
self.llm.callbacks,
self.llm.verbose,
tags,
self.llm.tags,
)
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.llm.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_manager_ = callback_manager.on_llm_start(
dumpd(self),
prompts,
invocation_params=params,
options=options,
)
generations = []
for prompt in prompts:
inputs_ = prompt
num_workers = None
batch_size = None
if num_workers is None:
if self.llm.pipeline._num_workers is None:
num_workers = 0
else:
num_workers = self.llm.pipeline._num_workers
if batch_size is None:
if self.llm.pipeline._batch_size is None:
batch_size = 1
else:
batch_size = self.llm.pipeline._batch_size
preprocess_params = {}
generate_kwargs = {}
preprocess_params.update(generate_kwargs)
forward_params = generate_kwargs
postprocess_params = {}
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
preprocess_params = {
**self.llm.pipeline._preprocess_params,
**preprocess_params,
}
forward_params = {
**self.llm.pipeline._forward_params,
**forward_params,
}
postprocess_params = {
**self.llm.pipeline._postprocess_params,
**postprocess_params,
}
self.llm.pipeline.call_count += 1
if (
self.llm.pipeline.call_count > 10
and self.llm.pipeline.framework == "pt"
and self.llm.pipeline.device.type == "cuda"
):
warnings.warn(
"You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a"
" dataset",
UserWarning,
)
model_inputs = self.llm.pipeline.preprocess(
inputs_, **preprocess_params
)
model_outputs = self.llm.pipeline.forward(
model_inputs, **forward_params
)
model_outputs["process"] = False
return model_outputs
output = LLMResult(generations=generations)
run_manager_.on_llm_end(output)
if run_manager_:
output.run = RunInfo(run_id=run_manager_.run_id)
response = output
outputs = [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
][0]
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
else:
_run_manager = (
run_manager or CallbackManagerForChainRun.get_noop_manager()
)
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {
k: v for k, v in inputs.items() if k != self.input_key
}
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
)
inputs["is_first"] = False
inputs["input_documents"] = input_docs
# Call predict on the LLM.
output = self.llm_chain(inputs, callbacks=_run_manager.get_child())
if "process" in output.keys() and not output["process"]:
return output
output = output[self.llm_chain.output_key]
extra_return_dict = {}
extra_return_dict[self.output_key] = output
outputs = extra_return_dict
run_manager.on_chain_end(outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,
inputs: Dict[str, str],
outputs: Dict[str, str],
return_only_outputs: bool = False,
) -> Dict[str, str]:
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
return {**inputs, **outputs}
def prep_inputs(
self, inputs: Union[Dict[str, Any], Any]
) -> Dict[str, str]:
"""Validate and prep inputs."""
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(
self.memory.memory_variables
)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def apply(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
if args and not kwargs:
if len(args) != 1:
raise ValueError(
"`run` supports only one positional argument."
)
return self(args[0], callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags)[
self.output_keys[0]
]
if not kwargs and not args:
raise ValueError(
"`run` supported with either positional arguments or keyword arguments,"
" but none were provided."
)
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"
f" but not both. Got args: {args} and kwargs: {kwargs}."
)
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of chain."""
if self.memory is not None:
raise ValueError("Saving of memory is not yet supported.")
_dict = super().dict()
_dict["_type"] = self._chain_type
return _dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the chain.
Args:
file_path: Path to file to save the chain to.
Example:
.. code-block:: python
chain.save(file_path="path/chain.yaml")
"""
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
chain_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(chain_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(chain_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
class BaseCombineDocumentsChain(Chain, ABC):
"""Base interface for chains combining documents."""
@@ -496,6 +79,12 @@ class BaseCombineDocumentsChain(Chain, ABC):
"""
return None
@abstractmethod
def combine_docs(
self, docs: List[Document], **kwargs: Any
) -> Tuple[str, dict]:
"""Combine documents into a single string."""
def _call(
self,
inputs: Dict[str, List[Document]],
@@ -507,49 +96,13 @@ class BaseCombineDocumentsChain(Chain, ABC):
docs = inputs[self.input_key]
# Other keys are assumed to be needed for LLM prediction
other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
doc_strings = [
format_document(doc, self.document_prompt) for doc in docs
]
# Join the documents together to put them in the prompt.
inputs = {
k: v
for k, v in other_keys.items()
if k in self.llm_chain.prompt.input_variables
}
inputs[self.document_variable_name] = self.document_separator.join(
doc_strings
output, extra_return_dict = self.combine_docs(
docs, callbacks=_run_manager.get_child(), **other_keys
)
# Call predict on the LLM.
output, extra_return_dict = (
self.llm_chain(inputs, callbacks=_run_manager.get_child())[
self.llm_chain.output_key
],
{},
)
extra_return_dict[self.output_key] = output
return extra_return_dict
from pydantic import BaseModel
class Generation(Serializable):
"""Output of a single generation."""
text: str
"""Generated text output."""
generation_info: Optional[Dict[str, Any]] = None
"""Raw generation info response from the provider"""
"""May include things like reason for finishing (e.g. in OpenAI)"""
# TODO: add log probs
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
class LLMChain(Chain):
"""Chain to run queries against LLMs.
@@ -600,13 +153,21 @@ class LLMChain(Chain):
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
prompts, stop = self.prep_prompts([inputs], run_manager=run_manager)
response = self.llm.generate_prompt(
response = self.generate([inputs], run_manager=run_manager)
return self.create_outputs(response)[0]
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
)
return self.create_outputs(response)[0]
def prep_prompts(
self,
@@ -662,6 +223,23 @@ class LLMChain(Chain):
for generation in response.generations
]
def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
callbacks: Callbacks to pass to LLMChain
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return self(kwargs, callbacks=callbacks)[self.output_key]
def predict_and_parse(
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
@@ -772,6 +350,14 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
prompt = self.llm_chain.prompt.format(**inputs)
return self.llm_chain.llm.get_num_tokens(prompt)
def combine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
) -> Tuple[str, dict]:
"""Stuff all documents into one prompt and pass to LLM."""
inputs = self._get_inputs(docs, **kwargs)
# Call predict on the LLM.
return self.llm_chain.predict(callbacks=callbacks, **inputs), {}
@property
def _chain_type(self) -> str:
return "stuff_documents_chain"

View File

@@ -1129,7 +1129,7 @@ class Langchain:
max_time=max_time,
num_return_sequences=num_return_sequences,
)
out = run_qa_db(
for r in run_qa_db(
query=instruction,
iinput=iinput,
context=context,
@@ -1170,8 +1170,689 @@ class Langchain:
auto_reduce_chunks=auto_reduce_chunks,
max_chunks=max_chunks,
device=self.device,
):
(
outr,
extra,
) = r # doesn't accumulate, new answer every yield, so only save that full answer
yield dict(response=outr, sources=extra)
if save_dir:
extra_dict = gen_hyper_langchain.copy()
extra_dict.update(
prompt_type=prompt_type,
inference_server=inference_server,
langchain_mode=langchain_mode,
langchain_action=langchain_action,
document_choice=document_choice,
num_prompt_tokens=num_prompt_tokens,
instruction=instruction,
iinput=iinput,
context=context,
)
save_generate_output(
prompt=prompt,
output=outr,
base_model=base_model,
save_dir=save_dir,
where_from="run_qa_db",
extra_dict=extra_dict,
)
if verbose:
print(
"Post-Generate Langchain: %s decoded_output: %s"
% (str(datetime.now()), len(outr) if outr else -1),
flush=True,
)
if outr or base_model in non_hf_types:
# if got no response (e.g. not showing sources and got no sources,
# so nothing to give to LLM), then slip through and ask LLM
# Or if llama/gptj, then just return since they had no response and can't go down below code path
# clear before return, since .then() never done if from API
clear_torch_cache()
return
if inference_server.startswith(
"openai"
) or inference_server.startswith("http"):
if inference_server.startswith("openai"):
import openai
where_from = "openai_client"
openai.api_key = os.getenv("OPENAI_API_KEY")
stop_sequences = list(
set(prompter.terminate_response + [prompter.PreResponse])
)
stop_sequences = [x for x in stop_sequences if x]
# OpenAI will complain if ask for too many new tokens, takes it as min in some sense, wrongly so.
max_new_tokens_openai = min(
max_new_tokens, model_max_length - num_prompt_tokens
)
gen_server_kwargs = dict(
temperature=temperature if do_sample else 0,
max_tokens=max_new_tokens_openai,
top_p=top_p if do_sample else 1,
frequency_penalty=0,
n=num_return_sequences,
presence_penalty=1.07
- repetition_penalty
+ 0.6, # so good default
)
if inference_server == "openai":
response = openai.Completion.create(
model=base_model,
prompt=prompt,
**gen_server_kwargs,
stop=stop_sequences,
stream=stream_output,
)
if not stream_output:
text = response["choices"][0]["text"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
collected_events = []
text = ""
for event in response:
collected_events.append(
event
) # save the event response
event_text = event["choices"][0][
"text"
] # extract the text
text += event_text # append the text
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
elif inference_server == "openai_chat":
response = openai.ChatCompletion.create(
model=base_model,
messages=[
{
"role": "system",
"content": "You are a helpful assistant.",
},
{
"role": "user",
"content": prompt,
},
],
stream=stream_output,
**gen_server_kwargs,
)
if not stream_output:
text = response["choices"][0]["message"]["content"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
text = ""
for chunk in response:
delta = chunk["choices"][0]["delta"]
if "content" in delta:
text += delta["content"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
raise RuntimeError(
"No such OpenAI mode: %s" % inference_server
)
elif inference_server.startswith("http"):
inference_server, headers = get_hf_server(inference_server)
from gradio_utils.grclient import GradioClient
from text_generation import Client as HFClient
if isinstance(model, GradioClient):
gr_client = model
hf_client = None
elif isinstance(model, HFClient):
gr_client = None
hf_client = model
else:
(
inference_server,
gr_client,
hf_client,
) = self.get_client_from_inference_server(
inference_server, base_model=base_model
)
# quick sanity check to avoid long timeouts, just see if can reach server
requests.get(
inference_server,
timeout=int(os.getenv("REQUEST_TIMEOUT_FAST", "10")),
)
if gr_client is not None:
# Note: h2oGPT gradio server could handle input token size issues for prompt,
# but best to handle here so send less data to server
chat_client = False
where_from = "gr_client"
client_langchain_mode = "Disabled"
client_langchain_action = LangChainAction.QUERY.value
gen_server_kwargs = dict(
temperature=temperature,
top_p=top_p,
top_k=top_k,
num_beams=num_beams,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
early_stopping=early_stopping,
max_time=max_time,
repetition_penalty=repetition_penalty,
num_return_sequences=num_return_sequences,
do_sample=do_sample,
chat=chat_client,
)
# account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection
if prompt_type in [
None,
"",
PromptType.plain.name,
PromptType.plain.value,
str(PromptType.plain.value),
]:
# if our prompt is plain, assume either correct or gradio server knows different prompt type,
# so pass empty prompt_Type
gr_prompt_type = ""
gr_prompt_dict = ""
gr_prompt = prompt # already prepared prompt
gr_context = ""
gr_iinput = ""
else:
# if already have prompt_type that is not plain, None, or '', then already applied some prompting
# But assume server can handle prompting, and need to avoid double-up.
# Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle
# So avoid "prompt" and let gradio server reconstruct from prompt_type we passed
# Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed,
# because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter
# since those won't appear
gr_context = context
gr_prompt = instruction
gr_iinput = iinput
gr_prompt_type = prompt_type
gr_prompt_dict = prompt_dict
client_kwargs = dict(
instruction=gr_prompt
if chat_client
else "", # only for chat=True
iinput=gr_iinput, # only for chat=True
context=gr_context,
# streaming output is supported, loops over and outputs each generation in streaming mode
# but leave stream_output=False for simple input/output mode
stream_output=stream_output,
**gen_server_kwargs,
prompt_type=gr_prompt_type,
prompt_dict=gr_prompt_dict,
instruction_nochat=gr_prompt
if not chat_client
else "",
iinput_nochat=gr_iinput, # only for chat=False
langchain_mode=client_langchain_mode,
langchain_action=client_langchain_action,
top_k_docs=top_k_docs,
chunk=chunk,
chunk_size=chunk_size,
document_choice=[DocumentChoices.All_Relevant.name],
)
api_name = "/submit_nochat_api" # NOTE: like submit_nochat but stable API for string dict passing
if not stream_output:
res = gr_client.predict(
str(dict(client_kwargs)), api_name=api_name
)
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
else:
job = gr_client.submit(
str(dict(client_kwargs)), api_name=api_name
)
text = ""
sources = ""
res_dict = dict(response=text, sources=sources)
while not job.done():
outputs_list = job.communicator.job.outputs
if outputs_list:
res = job.communicator.job.outputs[-1]
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
if gr_prompt_type == "plain":
# then gradio server passes back full prompt + text
prompt_and_text = text
else:
prompt_and_text = prompt + text
yield dict(
response=prompter.get_response(
prompt_and_text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
time.sleep(0.01)
# ensure get last output to avoid race
res_all = job.outputs()
if len(res_all) > 0:
res = res_all[-1]
res_dict = ast.literal_eval(res)
text = res_dict["response"]
sources = res_dict["sources"]
else:
# go with old text if last call didn't work
e = job.future._exception
if e is not None:
stre = str(e)
strex = "".join(
traceback.format_tb(e.__traceback__)
)
else:
stre = ""
strex = ""
print(
"Bad final response: %s %s %s %s %s: %s %s"
% (
base_model,
inference_server,
res_all,
prompt,
text,
stre,
strex,
),
flush=True,
)
if gr_prompt_type == "plain":
# then gradio server passes back full prompt + text
prompt_and_text = text
else:
prompt_and_text = prompt + text
yield dict(
response=prompter.get_response(
prompt_and_text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources=sources,
)
elif hf_client:
# HF inference server needs control over input tokens
where_from = "hf_client"
# prompt must include all human-bot like tokens, already added by prompt
# https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types
stop_sequences = list(
set(
prompter.terminate_response
+ [prompter.PreResponse]
)
)
stop_sequences = [x for x in stop_sequences if x]
gen_server_kwargs = dict(
do_sample=do_sample,
max_new_tokens=max_new_tokens,
# best_of=None,
repetition_penalty=repetition_penalty,
return_full_text=True,
seed=SEED,
stop_sequences=stop_sequences,
temperature=temperature,
top_k=top_k,
top_p=top_p,
# truncate=False, # behaves oddly
# typical_p=top_p,
# watermark=False,
# decoder_input_details=False,
)
# work-around for timeout at constructor time, will be issue if multi-threading,
# so just do something reasonable or max_time if larger
# lower bound because client is re-used if multi-threading
hf_client.timeout = max(300, max_time)
if not stream_output:
text = hf_client.generate(
prompt, **gen_server_kwargs
).generated_text
yield dict(
response=prompter.get_response(
text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
text = ""
for response in hf_client.generate_stream(
prompt, **gen_server_kwargs
):
if not response.token.special:
# stop_sequences
text_chunk = response.token.text
text += text_chunk
yield dict(
response=prompter.get_response(
prompt + text,
prompt=prompt,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
else:
raise RuntimeError(
"Failed to get client: %s" % inference_server
)
else:
raise RuntimeError(
"No such inference_server %s" % inference_server
)
if save_dir and text:
# save prompt + new text
extra_dict = gen_server_kwargs.copy()
extra_dict.update(
dict(
inference_server=inference_server,
num_prompt_tokens=num_prompt_tokens,
)
)
save_generate_output(
prompt=prompt,
output=text,
base_model=base_model,
save_dir=save_dir,
where_from=where_from,
extra_dict=extra_dict,
)
return
else:
assert not inference_server, (
"inferene_server=%s not supported" % inference_server
)
return out
if isinstance(tokenizer, str):
# pipeline
if tokenizer == "summarization":
key = "summary_text"
else:
raise RuntimeError("No such task type %s" % tokenizer)
# NOTE: uses max_length only
yield dict(
response=model(prompt, max_length=max_new_tokens)[0][key],
sources="",
)
if "mbart-" in base_model.lower():
assert src_lang is not None
tokenizer.src_lang = self.languages_covered()[src_lang]
stopping_criteria = get_stopping(
prompt_type,
prompt_dict,
tokenizer,
self.device,
model_max_length=tokenizer.model_max_length,
)
print(prompt)
# exit(0)
inputs = tokenizer(prompt, return_tensors="pt")
if debug and len(inputs["input_ids"]) > 0:
print("input_ids length", len(inputs["input_ids"][0]), flush=True)
input_ids = inputs["input_ids"].to(self.device)
# CRITICAL LIMIT else will fail
max_max_tokens = tokenizer.model_max_length
max_input_tokens = max_max_tokens - min_new_tokens
# NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py
input_ids = input_ids[:, -max_input_tokens:]
# required for falcon if multiple threads or asyncio accesses to model during generation
if use_cache is None:
use_cache = False if "falcon" in base_model else True
gen_config_kwargs = dict(
temperature=float(temperature),
top_p=float(top_p),
top_k=top_k,
num_beams=num_beams,
do_sample=do_sample,
repetition_penalty=float(repetition_penalty),
num_return_sequences=num_return_sequences,
renormalize_logits=True,
remove_invalid_values=True,
use_cache=use_cache,
)
token_ids = [
"eos_token_id",
"pad_token_id",
"bos_token_id",
"cls_token_id",
"sep_token_id",
]
for token_id in token_ids:
if (
hasattr(tokenizer, token_id)
and getattr(tokenizer, token_id) is not None
):
gen_config_kwargs.update(
{token_id: getattr(tokenizer, token_id)}
)
generation_config = GenerationConfig(**gen_config_kwargs)
gen_kwargs = dict(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_new_tokens=max_new_tokens, # prompt + new
min_new_tokens=min_new_tokens, # prompt + new
early_stopping=early_stopping, # False, True, "never"
max_time=max_time,
stopping_criteria=stopping_criteria,
)
if "gpt2" in base_model.lower():
gen_kwargs.update(
dict(
bos_token_id=tokenizer.bos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
)
elif "mbart-" in base_model.lower():
assert tgt_lang is not None
tgt_lang = self.languages_covered()[tgt_lang]
gen_kwargs.update(
dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])
)
else:
token_ids = ["eos_token_id", "bos_token_id", "pad_token_id"]
for token_id in token_ids:
if (
hasattr(tokenizer, token_id)
and getattr(tokenizer, token_id) is not None
):
gen_kwargs.update({token_id: getattr(tokenizer, token_id)})
decoder_kwargs = dict(
skip_special_tokens=True, clean_up_tokenization_spaces=True
)
decoder = functools.partial(tokenizer.decode, **decoder_kwargs)
decoder_raw_kwargs = dict(
skip_special_tokens=False, clean_up_tokenization_spaces=True
)
decoder_raw = functools.partial(tokenizer.decode, **decoder_raw_kwargs)
with torch.no_grad():
have_lora_weights = lora_weights not in [no_lora_str, "", None]
context_class_cast = (
NullContext
if self.device == "cpu"
or have_lora_weights
or self.device == "mps"
else torch.autocast
)
with context_class_cast(self.device):
# protection for gradio not keeping track of closed users,
# else hit bitsandbytes lack of thread safety:
# https://github.com/h2oai/h2ogpt/issues/104
# but only makes sense if concurrency_count == 1
context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
if verbose:
print("Pre-Generate: %s" % str(datetime.now()), flush=True)
decoded_output = None
with context_class("generate.lock"):
if verbose:
print("Generate: %s" % str(datetime.now()), flush=True)
# decoded tokenized prompt can deviate from prompt due to special characters
inputs_decoded = decoder(input_ids[0])
inputs_decoded_raw = decoder_raw(input_ids[0])
if inputs_decoded == prompt:
# normal
pass
elif inputs_decoded.lstrip() == prompt.lstrip():
# sometimes extra space in front, make prompt same for prompt removal
prompt = inputs_decoded
elif inputs_decoded_raw == prompt:
# some models specify special tokens that are part of normal prompt, so can't skip them
inputs_decoded = prompt = inputs_decoded_raw
decoder = decoder_raw
decoder_kwargs = decoder_raw_kwargs
elif inputs_decoded_raw.replace("<unk> ", "").replace(
"<unk>", ""
).replace("\n", " ").replace(" ", "") == prompt.replace(
"\n", " "
).replace(
" ", ""
):
inputs_decoded = prompt = inputs_decoded_raw
decoder = decoder_raw
decoder_kwargs = decoder_raw_kwargs
else:
if verbose:
print(
"WARNING: Special characters in prompt",
flush=True,
)
if stream_output:
skip_prompt = False
streamer = H2OTextIteratorStreamer(
tokenizer,
skip_prompt=skip_prompt,
block=False,
**decoder_kwargs,
)
gen_kwargs.update(dict(streamer=streamer))
target = wrapped_partial(
self.generate_with_exceptions,
model.generate,
prompt=prompt,
inputs_decoded=inputs_decoded,
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
**gen_kwargs,
)
bucket = queue.Queue()
thread = EThread(
target=target, streamer=streamer, bucket=bucket
)
thread.start()
outputs = ""
try:
for new_text in streamer:
if bucket.qsize() > 0 or thread.exc:
thread.join()
outputs += new_text
yield dict(
response=prompter.get_response(
outputs,
prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
except BaseException:
# if any exception, raise that exception if was from thread, first
if thread.exc:
raise thread.exc
raise
finally:
# clear before return, since .then() never done if from API
clear_torch_cache()
# in case no exception and didn't join with thread yet, then join
if not thread.exc:
thread.join()
# in case raise StopIteration or broke queue loop in streamer, but still have exception
if thread.exc:
raise thread.exc
decoded_output = outputs
else:
try:
outputs = model.generate(**gen_kwargs)
finally:
clear_torch_cache() # has to be here for API submit_nochat_api since.then() not called
outputs = [decoder(s) for s in outputs.sequences]
yield dict(
response=prompter.get_response(
outputs,
prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response,
),
sources="",
)
if outputs and len(outputs) >= 1:
decoded_output = prompt + outputs[0]
if save_dir and decoded_output:
extra_dict = gen_config_kwargs.copy()
extra_dict.update(
dict(num_prompt_tokens=num_prompt_tokens)
)
save_generate_output(
prompt=prompt,
output=decoded_output,
base_model=base_model,
save_dir=save_dir,
where_from="evaluate_%s" % str(stream_output),
extra_dict=gen_config_kwargs,
)
if verbose:
print(
"Post-Generate: %s decoded_output: %s"
% (
str(datetime.now()),
len(decoded_output) if decoded_output else -1,
),
flush=True,
)
return outputs[0]
inputs_list_names = list(inspect.signature(evaluate).parameters)
global inputs_kwargs_list

View File

@@ -436,7 +436,7 @@ class GradioInference(LLM):
chat_client: bool = False
return_full_text: bool = True
stream_output: bool = Field(False, alias="stream")
stream: bool = False
sanitize_bot_response: bool = False
prompter: Any = None
@@ -481,7 +481,7 @@ class GradioInference(LLM):
# so server should get prompt_type or '', not plain
# This is good, so gradio server can also handle stopping.py conditions
# this is different than TGI server that uses prompter to inject prompt_type prompting
stream_output = self.stream_output
stream_output = self.stream
gr_client = self.client
client_langchain_mode = "Disabled"
client_langchain_action = LangChainAction.QUERY.value
@@ -596,7 +596,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
inference_server_url: str = ""
timeout: int = 300
headers: dict = None
stream_output: bool = Field(False, alias="stream")
stream: bool = False
sanitize_bot_response: bool = False
prompter: Any = None
tokenizer: Any = None
@@ -663,7 +663,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
# lower bound because client is re-used if multi-threading
self.client.timeout = max(300, self.timeout)
if not self.stream_output:
if not self.stream:
res = self.client.generate(
prompt,
**gen_server_kwargs,
@@ -852,7 +852,7 @@ def get_llm(
top_p=top_p,
# typical_p=top_p,
callbacks=callbacks if stream_output else None,
stream_output=stream_output,
stream=stream_output,
prompter=prompter,
tokenizer=tokenizer,
client=hf_client,
@@ -2510,7 +2510,8 @@ def _run_qa_db(
formatted_doc_chunks = "\n\n".join(
[get_url(x) + "\n\n" + x.page_content for x in docs]
)
return formatted_doc_chunks, ""
yield formatted_doc_chunks, ""
return
if not docs and langchain_action in [
LangChainAction.SUMMARIZE_MAP.value,
LangChainAction.SUMMARIZE_ALL.value,
@@ -2522,7 +2523,8 @@ def _run_qa_db(
else "No documents to summarize."
)
extra = ""
return ret, extra
yield ret, extra
return
if not docs and langchain_mode not in [
LangChainMode.DISABLED.value,
LangChainMode.CHAT_LLM.value,
@@ -2534,7 +2536,8 @@ def _run_qa_db(
else "No documents to query."
)
extra = ""
return ret, extra
yield ret, extra
return
if chain is None and model_name not in non_hf_types:
# here if no docs at all and not HF type
@@ -2554,7 +2557,22 @@ def _run_qa_db(
)
with context_class_cast(args.device):
answer = chain()
return answer
if not use_context:
ret = answer["output_text"]
extra = ""
yield ret, extra
elif answer is not None:
ret, extra = get_sources_answer(
query,
answer,
scores,
show_rank,
answer_with_sources,
verbose=verbose,
)
yield ret, extra
return
def get_similarity_chain(

View File

@@ -3,11 +3,13 @@ from apps.stable_diffusion.src.utils.utils import _compile_module
from io import BytesIO
import torch_mlir
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from stopping import get_stopping
from prompter import Prompter, PromptType
from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
@@ -20,7 +22,7 @@ import gc
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
from apps.stable_diffusion.src import args
# Brevitas
@@ -29,8 +31,14 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
def brevitasmatmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -39,21 +47,30 @@ def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_s
raise ValueError("Input shapes not supported.")
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
def brevitasmatmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
def brevitasmatmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics,
]
global_device = "cuda"
global_precision = "fp16"
@@ -229,7 +246,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
ts_graph,
[*h2ogptCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["quant.matmul_rhs_group_quant"],
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -237,7 +254,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
@@ -256,11 +273,6 @@ class H2OGPTSHARKModel(torch.nn.Module):
bytecode = bytecode_stream.getvalue()
del module
bytecode = save_mlir(
bytecode,
model_name=f"h2ogpt_{precision}",
frontend="torch",
)
return bytecode
def forward(self, input_ids, attention_mask):
@@ -273,215 +285,7 @@ class H2OGPTSHARKModel(torch.nn.Module):
return result
def decode_tokens(tokenizer, res_tokens):
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = tokenizer.decode(res_tokens, skip_special_tokens=True)
return res_str
def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs):
del generate_kwargs["max_time"]
generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to(
device=tensor_device
)
generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to(
device=tensor_device
)
truncated_input_ids = []
stopping_criteria = generate_kwargs["stopping_criteria"]
generation_config_ = GenerationConfig.from_model_config(model.config)
generation_config = copy.deepcopy(generation_config_)
model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
stopping_criteria = (
stopping_criteria
if stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
model_kwargs,
) = model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else model_kwargs.pop("input_ids")
)
input_ids_seq_length = input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria,
)
logits_warper = model._get_logits_warper(generation_config)
(
input_ids,
model_kwargs,
) = model._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=model.config.is_encoder_decoder, # False
**model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=tensor_device)
if eos_token_id is not None
else None
)
pad_token_id = generation_config.pad_token_id
eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(
input_ids.shape[0],
dtype=torch.long,
device=input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
res_tokens = []
while True:
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs
)
outputs = h2ogpt_shark_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = next_token * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1)
model_kwargs["past_key_values"] = None
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
truncated_input_ids.append(input_ids[:, 0])
input_ids = input_ids[:, 1:]
model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:]
new_word = tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
res_tokens.append(next_token)
if new_word == "<0x0A>":
print("\n", end="", flush=True)
else:
print(f"{new_word}", end=" ", flush=True)
part_str = decode_tokens(tokenizer, res_tokens)
yield part_str
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_token.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if unfinished_sequences.max() == 0 or stopping_criteria(
input_ids, scores
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
torch.cuda.empty_cache()
gc.collect()
res_str = decode_tokens(tokenizer, res_tokens)
yield res_str
h2ogpt_model = H2OGPTSHARKModel()
def pad_or_truncate_inputs(
@@ -694,6 +498,233 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
)
return records
def generate_new_token(self):
model_inputs = self.model.prepare_inputs_for_generation(
self.input_ids, **self.model_kwargs
)
outputs = h2ogpt_model.forward(
model_inputs["input_ids"], model_inputs["attention_mask"]
)
if args.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
# pre-process distribution
next_token_scores = self.logits_processor(
self.input_ids, next_token_logits
)
next_token_scores = self.logits_warper(
self.input_ids, next_token_scores
)
# sample
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
# finished sentences should have their next token be a padding token
if self.eos_token_id is not None:
if self.pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_token = (
next_token * self.unfinished_sequences
+ self.pad_token_id * (1 - self.unfinished_sequences)
)
self.input_ids = torch.cat(
[self.input_ids, next_token[:, None]], dim=-1
)
self.model_kwargs["past_key_values"] = None
if "attention_mask" in self.model_kwargs:
attention_mask = self.model_kwargs["attention_mask"]
self.model_kwargs["attention_mask"] = torch.cat(
[
attention_mask,
attention_mask.new_ones((attention_mask.shape[0], 1)),
],
dim=-1,
)
self.truncated_input_ids.append(self.input_ids[:, 0])
self.input_ids = self.input_ids[:, 1:]
self.model_kwargs["attention_mask"] = self.model_kwargs[
"attention_mask"
][:, 1:]
return next_token
def generate_token(self, **generate_kwargs):
del generate_kwargs["max_time"]
self.truncated_input_ids = []
generation_config_ = GenerationConfig.from_model_config(
self.model.config
)
generation_config = copy.deepcopy(generation_config_)
self.model_kwargs = generation_config.update(**generate_kwargs)
logits_processor = LogitsProcessorList()
self.stopping_criteria = (
self.stopping_criteria
if self.stopping_criteria is not None
else StoppingCriteriaList()
)
eos_token_id = generation_config.eos_token_id
generation_config.pad_token_id = eos_token_id
(
inputs_tensor,
model_input_name,
self.model_kwargs,
) = self.model._prepare_model_inputs(
None, generation_config.bos_token_id, self.model_kwargs
)
batch_size = inputs_tensor.shape[0]
self.model_kwargs[
"output_attentions"
] = generation_config.output_attentions
self.model_kwargs[
"output_hidden_states"
] = generation_config.output_hidden_states
self.model_kwargs["use_cache"] = generation_config.use_cache
self.input_ids = (
inputs_tensor
if model_input_name == "input_ids"
else self.model_kwargs.pop("input_ids")
)
input_ids_seq_length = self.input_ids.shape[-1]
generation_config.max_length = (
generation_config.max_new_tokens + input_ids_seq_length
)
self.logits_processor = self.model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=None,
logits_processor=logits_processor,
)
self.stopping_criteria = self.model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=self.stopping_criteria,
)
self.logits_warper = self.model._get_logits_warper(generation_config)
(
self.input_ids,
self.model_kwargs,
) = self.model._expand_inputs_for_generation(
input_ids=self.input_ids,
expand_size=generation_config.num_return_sequences, # 1
is_encoder_decoder=self.model.config.is_encoder_decoder, # False
**self.model_kwargs,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
self.eos_token_id_tensor = (
torch.tensor(eos_token_id).to(device=tensor_device)
if eos_token_id is not None
else None
)
self.pad_token_id = generation_config.pad_token_id
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
# init attention / hidden states / scores tuples
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
self.input_ids.shape[0],
dtype=torch.long,
device=self.input_ids.device,
)
timesRan = 0
import time
start = time.time()
print("\n")
while True:
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
add_special_tokens=False,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
print(f"{new_word}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
self.unfinished_sequences = self.unfinished_sequences.mul(
next_token.tile(self.eos_token_id_tensor.shape[0], 1)
.ne(self.eos_token_id_tensor.unsqueeze(1))
.prod(dim=0)
)
# stop when each sentence is finished
if (
self.unfinished_sequences.max() == 0
or self.stopping_criteria(self.input_ids, self.scores)
):
break
timesRan = timesRan + 1
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / timesRan
)
)
self.input_ids = torch.cat(
[
torch.tensor(self.truncated_input_ids)
.to(device=tensor_device)
.unsqueeze(dim=0),
self.input_ids,
],
dim=-1,
)
torch.cuda.empty_cache()
gc.collect()
return self.input_ids
def _forward(self, model_inputs, **generate_kwargs):
if self.can_stop:
stopping_criteria = get_stopping(
@@ -753,13 +784,19 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
input_ids, attention_mask = pad_or_truncate_inputs(
input_ids, attention_mask, max_padding_length=max_padding_length
)
self.stopping_criteria = generate_kwargs["stopping_criteria"]
return_dict = {
"model": self.model,
"tokenizer": self.tokenizer,
generated_sequence = self.generate_token(
input_ids=input_ids,
attention_mask=attention_mask,
**generate_kwargs,
)
out_b = generated_sequence.shape[0]
generated_sequence = generated_sequence.reshape(
in_b, out_b // in_b, *generated_sequence.shape[1:]
)
return {
"generated_sequence": generated_sequence,
"input_ids": input_ids,
"attention_mask": attention_mask,
"attention_mask": attention_mask,
"prompt_text": prompt_text,
}
return_dict = {**return_dict, **generate_kwargs}
return return_dict

View File

@@ -1,4 +1,5 @@
import os
import fire
from gpt_langchain import (
path_to_docs,
@@ -201,3 +202,7 @@ def make_db_main(
if verbose:
print("DONE", flush=True)
return db, collection_name
if __name__ == "__main__":
fire.Fire(make_db_main)

View File

@@ -1,442 +0,0 @@
from pathlib import Path
import argparse
from argparse import RawTextHelpFormatter
import re, gc
"""
This script can be used as a standalone utility to convert IRs to dynamic + combine them.
Following are the various ways this script can be used :-
a. To convert a single Linalg IR to dynamic IR:
--dynamic --first_ir_path=<PATH TO FIRST IR>
b. To convert two Linalg IRs to dynamic IR:
--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>
c. To combine two Linalg IRs into one:
--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
d. To convert both IRs into dynamic as well as combine the IRs:
--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>
NOTE: For dynamic you'll also need to provide the following set of flags:-
i. For First Llama : --dynamic_input_size (DEFAULT: 19)
ii. For Second Llama: --model_name (DEFAULT: llama2_7b)
--precision (DEFAULT: 'int4')
You may use --save_dynamic to also save the dynamic IR in option d above.
Else for option a. and b. the dynamic IR(s) will get saved by default.
"""
def combine_mlir_scripts(
first_vicuna_mlir,
second_vicuna_mlir,
output_name,
return_ir=True,
):
print(f"[DEBUG] combining first and second mlir")
print(f"[DEBUG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
f1 = []
f2 = []
print(f"[DEBUG] processing first vicuna mlir")
first_vicuna_mlir = first_vicuna_mlir.splitlines()
while first_vicuna_mlir:
line = first_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
f1 = f1[:-1]
del first_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps1):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
maps1[i] = map_line
f1 = [
re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
for func_line in f1
]
print(f"[DEBUG] processing second vicuna mlir")
second_vicuna_mlir = second_vicuna_mlir.splitlines()
while second_vicuna_mlir:
line = second_vicuna_mlir.pop(0)
if re.search("#map\d*\s*=", line):
maps2.append(line)
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
f2 = f2[:-1]
del second_vicuna_mlir
gc.collect()
for i, map_line in enumerate(maps2):
map_var = map_line.split(" ")[0]
map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
maps2[i] = map_line
f2 = [
re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
for func_line in f2
]
module_start = 'module attributes {torch.debug_module_name = "_lambda"} {'
module_end = "}"
global_vars = []
vnames = []
global_var_loading1 = []
global_var_loading2 = []
print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
while constants:
constant = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
vbody = re.sub("arith.constant", "", vbody)
vbody = vbody.strip()
if len(vbody.split(":")) < 2:
print(constant)
vdtype = vbody.split(":")[-1].strip()
fixed_vdtype = vdtype
if "c1_i64" in vname:
print(constant)
counter += 1
if counter == 2:
counter = 0
print("detected duplicate")
continue
vnames.append(vname)
if "true" not in vname:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
else:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : i1"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
new_f1, new_f2 = [], []
print(f"[DEBUG] processing f1")
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
new_f1.append(global_var)
else:
new_f1.append(line)
print(f"[DEBUG] processing f2")
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
if (
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
in global_var
):
print(global_var)
new_f2.append(global_var)
else:
new_f2.append(line)
f1 = new_f1
f2 = new_f2
del new_f1
del new_f2
gc.collect()
print(
[
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
for x in [maps1, maps2, global_vars, f1, f2]
]
)
# doing it this way rather than assembling the whole string
# to prevent OOM with 64GiB RAM when encoding the file.
print(f"[DEBUG] Saving mlir to {output_name}")
with open(output_name, "w+") as f_:
f_.writelines(line + "\n" for line in maps1)
f_.writelines(line + "\n" for line in maps2)
f_.writelines(line + "\n" for line in [module_start])
f_.writelines(line + "\n" for line in global_vars)
f_.writelines(line + "\n" for line in f1)
f_.writelines(line + "\n" for line in f2)
f_.writelines(line + "\n" for line in [module_end])
del maps1
del maps2
del module_start
del global_vars
del f1
del f2
del module_end
gc.collect()
if return_ir:
print(f"[DEBUG] Reading combined mlir back in")
with open(output_name, "rb") as f:
return f.read()
def write_in_dynamic_inputs0(module, dynamic_input_size):
print("[DEBUG] writing dynamic inputs to first vicuna")
# Current solution for ensuring mlir files support dynamic inputs
# TODO: find a more elegant way to implement this
new_lines = []
module = module.splitlines()
while module:
line = module.pop(0)
line = re.sub(f"{dynamic_input_size}x", "?x", line)
if "?x" in line:
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
line = re.sub(f" {dynamic_input_size},", " %dim,", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
)
if "arith.cmpi" in line:
line = re.sub(f"c{dynamic_input_size}", "dim", line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
continue
new_lines.append(line)
return "\n".join(new_lines)
def write_in_dynamic_inputs1(module, model_name, precision):
print("[DEBUG] writing dynamic inputs to second vicuna")
def remove_constant_dim(line):
if "c19_i64" in line:
line = re.sub("c19_i64", "dim_i64", line)
if "19x" in line:
line = re.sub("19x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
if "tensor.empty" in line and "?x?" in line:
line = re.sub(
"tensor.empty\(%dim\)",
"tensor.empty(%dim, %dim)",
line,
)
if "arith.cmpi" in line:
line = re.sub("c19", "dim", line)
if " 19," in line:
line = re.sub(" 19,", " %dim,", line)
if "x20x" in line or "<20x" in line:
line = re.sub("20x", "?x", line)
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
if " 20," in line:
line = re.sub(" 20,", " %dimp1,", line)
return line
module = module.splitlines()
new_lines = []
# Using a while loop and the pop method to avoid creating a copy of module
if "llama2_13b" in model_name:
pkv_tensor_shape = "tensor<1x40x?x128x"
elif "llama2_70b" in model_name:
pkv_tensor_shape = "tensor<1x8x?x128x"
else:
pkv_tensor_shape = "tensor<1x32x?x128x"
if precision in ["fp16", "int4", "int8"]:
pkv_tensor_shape += "f16>"
else:
pkv_tensor_shape += "f32>"
while module:
line = module.pop(0)
if "%c19_i64 = arith.constant 19 : i64" in line:
new_lines.append("%c2 = arith.constant 2 : index")
new_lines.append(
f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}"
)
new_lines.append(
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
)
continue
if "%c2 = arith.constant 2 : index" in line:
continue
if "%c20_i64 = arith.constant 20 : i64" in line:
new_lines.append("%c1_i64 = arith.constant 1 : i64")
new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
new_lines.append(
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
)
continue
line = remove_constant_dim(line)
new_lines.append(line)
return "\n".join(new_lines)
def save_dynamic_ir(ir_to_save, output_file):
if not ir_to_save:
return
# We only get string output from the dynamic conversion utility.
from contextlib import redirect_stdout
with open(output_file, "w") as f:
with redirect_stdout(f):
print(ir_to_save)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="llama ir utility",
description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n"
+ "\tFollowing are the various ways this script can be used :-\n"
+ "\t\ta. To convert a single Linalg IR to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO FIRST IR>\n"
+ "\t\tb. To convert two Linalg IRs to dynamic IR:\n"
+ "\t\t\t--dynamic --first_ir_path=<PATH TO SECOND IR> --first_ir_path=<PATH TO SECOND IR>\n"
+ "\t\tc. To combine two Linalg IRs into one:\n"
+ "\t\t\t--combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n"
+ "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n"
+ "\t\t\t--dynamic --combine --first_ir_path=<PATH TO FIRST IR> --second_ir_path=<PATH TO SECOND IR>\n\n"
+ "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n"
+ "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n"
+ "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n"
+ "\t\t\t--precision (DEFAULT: 'int4')\n"
+ "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n"
+ "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--precision",
"-p",
default="int4",
choices=["fp32", "fp16", "int8", "int4"],
help="Precision of the concerned IR",
)
parser.add_argument(
"--model_name",
type=str,
default="llama2_7b",
choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
"--first_ir_path",
default=None,
help="path to first llama mlir file",
)
parser.add_argument(
"--second_ir_path",
default=None,
help="path to second llama mlir file",
)
parser.add_argument(
"--dynamic_input_size",
type=int,
default=19,
help="Specify the static input size to replace with dynamic dim.",
)
parser.add_argument(
"--dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
parser.add_argument(
"--save_dynamic",
default=False,
action=argparse.BooleanOptionalAction,
help="Save the individual IR(s) after converting to dynamic",
)
parser.add_argument(
"--combine",
default=False,
action=argparse.BooleanOptionalAction,
help="Converts the IR(s) to dynamic",
)
args, unknown = parser.parse_known_args()
dynamic = args.dynamic
combine = args.combine
assert (
dynamic or combine
), "neither `dynamic` nor `combine` flag is turned on"
first_ir_path = args.first_ir_path
second_ir_path = args.second_ir_path
assert first_ir_path or second_ir_path, "no input ir has been provided"
if combine:
assert (
first_ir_path and second_ir_path
), "you will need to provide both IRs to combine"
precision = args.precision
model_name = args.model_name
dynamic_input_size = args.dynamic_input_size
save_dynamic = args.save_dynamic
print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}")
print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}")
if dynamic and not combine:
save_dynamic = True
first_ir = None
first_dynamic_ir_name = None
second_ir = None
second_dynamic_ir_name = None
if first_ir_path:
first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic"
with open(first_ir_path, "r") as f:
first_ir = f.read()
if second_ir_path:
second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic"
with open(second_ir_path, "r") as f:
second_ir = f.read()
if dynamic:
first_ir = (
write_in_dynamic_inputs0(first_ir, dynamic_input_size)
if first_ir
else None
)
second_ir = (
write_in_dynamic_inputs1(second_ir, model_name, precision)
if second_ir
else None
)
if save_dynamic:
save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir")
save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir")
if combine:
combine_mlir_scripts(
first_ir,
second_ir,
f"{model_name}_{precision}.mlir",
return_ir=False,
)

View File

@@ -46,7 +46,6 @@ def compile_stableLM(
model_vmfb_name,
device="cuda",
precision="fp32",
debug=False,
):
from shark.shark_inference import SharkInference
@@ -93,7 +92,7 @@ def compile_stableLM(
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug
vmfb_path.parent.absolute(), vmfb_path.stem
)
print("Saved vmfb at ", str(path))

File diff suppressed because it is too large Load Diff

View File

@@ -1,94 +0,0 @@
# -*- mode: python ; coding: utf-8 -*-
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import collect_submodules
from PyInstaller.utils.hooks import copy_metadata
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)
datas = []
datas += collect_data_files('torch')
datas += copy_metadata('torch')
datas += copy_metadata('tqdm')
datas += copy_metadata('regex')
datas += copy_metadata('requests')
datas += copy_metadata('packaging')
datas += copy_metadata('filelock')
datas += copy_metadata('numpy')
datas += copy_metadata('tokenizers')
datas += copy_metadata('importlib_metadata')
datas += copy_metadata('torch-mlir')
datas += copy_metadata('omegaconf')
datas += copy_metadata('safetensors')
datas += copy_metadata('huggingface-hub')
datas += copy_metadata('sentencepiece')
datas += copy_metadata("pyyaml")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files('diffusers')
datas += collect_data_files('transformers')
datas += collect_data_files('opencv-python')
datas += collect_data_files('pytorch_lightning')
datas += collect_data_files('skimage')
datas += collect_data_files('gradio')
datas += collect_data_files('gradio_client')
datas += collect_data_files('iree')
datas += collect_data_files('google-cloud-storage')
datas += collect_data_files('py-cpuinfo')
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
binaries = []
block_cipher = None
hiddenimports = ['shark', 'shark.shark_inference', 'apps']
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
a = Analysis(
['scripts/vicuna.py'],
pathex=['.'],
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='shark_llama_cli',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

View File

@@ -1,876 +0,0 @@
import argparse
import json
import re
from io import BytesIO
from pathlib import Path
from tqdm import tqdm
from typing import List, Optional, Tuple, Union
import numpy as np
import iree.runtime
import itertools
import subprocess
import torch
import torch_mlir
from torch_mlir import TensorPlaceholder
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
LlamaPreTrainedModel,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
FirstVicunaLayer,
SecondVicunaLayer,
CompiledVicunaLayer,
ShardedVicunaModel,
LMHead,
LMHeadCompiled,
VicunaEmbedding,
VicunaEmbeddingCompiled,
VicunaNorm,
VicunaNormCompiled,
)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna7B,
)
from apps.language_models.utils import (
get_vmfb_from_path,
)
from shark.shark_downloader import download_public_file
from shark.shark_importer import get_f16_inputs
from shark.shark_inference import SharkInference
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaRMSNorm,
_make_causal_mask,
_expand_mask,
)
from torch import nn
from time import time
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config)
for _ in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
t1 = time()
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = (
use_cache if use_cache is not None else self.config.use_cache
)
return_dict = (
return_dict
if return_dict is not None
else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = (
seq_length_with_past + past_key_values_length
)
if position_ids is None:
device = (
input_ids.device
if input_ids is not None
else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer.forward(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[1:],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except:
_ = 10
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
next_cache = tuple(itertools.chain.from_iterable(next_cache))
print(f"Token generated in {time() - t1} seconds")
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class EightLayerLayerSV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(
self,
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
):
pkvs = [
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
]
new_pkvs = []
for layer, pkv in zip(self.layers, pkvs):
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=(
pkv[0],
pkv[1],
),
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
class EightLayerLayerFV(torch.nn.Module):
def __init__(self, layers):
super().__init__()
assert len(layers) == 8
self.layers = layers
def forward(self, hidden_states, attention_mask, position_ids):
new_pkvs = []
for layer in self.layers:
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=None,
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
return (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
class CompiledEightLayerLayerSV(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions=False,
use_cache=True,
):
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
pkv00 = pkv00.detatch()
pkv01 = pkv01.detatch()
pkv10 = pkv10.detatch()
pkv11 = pkv11.detatch()
pkv20 = pkv20.detatch()
pkv21 = pkv21.detatch()
pkv30 = pkv30.detatch()
pkv31 = pkv31.detatch()
pkv40 = pkv40.detatch()
pkv41 = pkv41.detatch()
pkv50 = pkv50.detatch()
pkv51 = pkv51.detatch()
pkv60 = pkv60.detatch()
pkv61 = pkv61.detatch()
pkv70 = pkv70.detatch()
pkv71 = pkv71.detatch()
output = self.model(
"forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
return (
output[0],
(output[1][0], output[1][1]),
(output[2][0], output[2][1]),
(output[3][0], output[3][1]),
(output[4][0], output[4][1]),
(output[5][0], output[5][1]),
(output[6][0], output[6][1]),
(output[7][0], output[7][1]),
(output[8][0], output[8][1]),
)
def forward_compressed(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = (
input_ids.device if input_ids is not None else inputs_embeds.device
)
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.compressedlayers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = (
past_key_values[8 * idx : 8 * (idx + 1)]
if past_key_values is not None
else None
)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1],
)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_cache,
all_hidden_states,
all_self_attns,
]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class CompiledEightLayerLayer(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value=None,
output_attentions=False,
use_cache=True,
):
t2 = time()
if past_key_value is None:
try:
hidden_states = np.asarray(hidden_states, hidden_states.dtype)
except:
pass
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
t1 = time()
output = self.model(
"first_vicuna_forward",
(hidden_states, attention_mask, position_ids),
send_to_host=False,
)
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2
else:
(
(pkv00, pkv01),
(pkv10, pkv11),
(pkv20, pkv21),
(pkv30, pkv31),
(pkv40, pkv41),
(pkv50, pkv51),
(pkv60, pkv61),
(pkv70, pkv71),
) = past_key_value
try:
hidden_states = hidden_states.detach()
attention_mask = attention_mask.detach()
position_ids = position_ids.detach()
pkv00 = pkv00.detach()
pkv01 = pkv01.detach()
pkv10 = pkv10.detach()
pkv11 = pkv11.detach()
pkv20 = pkv20.detach()
pkv21 = pkv21.detach()
pkv30 = pkv30.detach()
pkv31 = pkv31.detach()
pkv40 = pkv40.detach()
pkv41 = pkv41.detach()
pkv50 = pkv50.detach()
pkv51 = pkv51.detach()
pkv60 = pkv60.detach()
pkv61 = pkv61.detach()
pkv70 = pkv70.detach()
pkv71 = pkv71.detach()
except:
x = 10
t1 = time()
if type(hidden_states) == iree.runtime.array_interop.DeviceArray:
hidden_states = np.array(hidden_states, hidden_states.dtype)
hidden_states = torch.tensor(hidden_states)
hidden_states = hidden_states.detach()
output = self.model(
"second_vicuna_forward",
(
hidden_states,
attention_mask,
position_ids,
pkv00,
pkv01,
pkv10,
pkv11,
pkv20,
pkv21,
pkv30,
pkv31,
pkv40,
pkv41,
pkv50,
pkv51,
pkv60,
pkv61,
pkv70,
pkv71,
),
send_to_host=False,
)
print(f"{time() - t1}")
del pkv00
del pkv01
del pkv10
del pkv11
del pkv20
del pkv21
del pkv30
del pkv31
del pkv40
del pkv41
del pkv50
del pkv51
del pkv60
del pkv61
del pkv70
del pkv71
output2 = (
output[0],
(
output[1],
output[2],
),
(
output[3],
output[4],
),
(
output[5],
output[6],
),
(
output[7],
output[8],
),
(
output[9],
output[10],
),
(
output[11],
output[12],
),
(
output[13],
output[14],
),
(
output[15],
output[16],
),
)
return output2

View File

@@ -1,13 +1,15 @@
import torch
from transformers import AutoModelForCausalLM
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
class FirstVicuna(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
@@ -16,24 +18,15 @@ class FirstVicuna(torch.nn.Module):
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=self.accumulates,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -47,559 +40,6 @@ class FirstVicuna(torch.nn.Module):
def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
return_vals = []
token = torch.argmax(op.logits[:, -1, :], dim=1)
return_vals.append(token)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0].transpose(1,2))
return_vals.append(item[1].transpose(1,2))
return tuple(return_vals)
class SecondVicuna7B(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
):
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
)
past_key_values = [(x[0].transpose(1,2), x[0].transpose(1,2)) for x in past_key_values]
past_key_values = tuple(past_key_values)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
token = torch.argmax(op.logits[:, -1, :], dim=1)
return_vals.append(token)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
return_vals.append(item[0].transpose(1,2))
return_vals.append(item[1].transpose(1,2))
return tuple(return_vals)
class SecondVicuna13B(torch.nn.Module):
def __init__(
self,
model_path,
precision="int8",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
if "llama2" in model_name:
kwargs["use_auth_token"] = hf_auth_token
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")
def forward(
self,
i0,
i1,
i2,
i3,
i4,
i5,
i6,
i7,
i8,
i9,
i10,
i11,
i12,
i13,
i14,
i15,
i16,
i17,
i18,
i19,
i20,
i21,
i22,
i23,
i24,
i25,
i26,
i27,
i28,
i29,
i30,
i31,
i32,
i33,
i34,
i35,
i36,
i37,
i38,
i39,
i40,
i41,
i42,
i43,
i44,
i45,
i46,
i47,
i48,
i49,
i50,
i51,
i52,
i53,
i54,
i55,
i56,
i57,
i58,
i59,
i60,
i61,
i62,
i63,
i64,
i65,
i66,
i67,
i68,
i69,
i70,
i71,
i72,
i73,
i74,
i75,
i76,
i77,
i78,
i79,
i80,
):
token = i0
past_key_values = (
(i1, i2),
(
i3,
i4,
),
(
i5,
i6,
),
(
i7,
i8,
),
(
i9,
i10,
),
(
i11,
i12,
),
(
i13,
i14,
),
(
i15,
i16,
),
(
i17,
i18,
),
(
i19,
i20,
),
(
i21,
i22,
),
(
i23,
i24,
),
(
i25,
i26,
),
(
i27,
i28,
),
(
i29,
i30,
),
(
i31,
i32,
),
(
i33,
i34,
),
(
i35,
i36,
),
(
i37,
i38,
),
(
i39,
i40,
),
(
i41,
i42,
),
(
i43,
i44,
),
(
i45,
i46,
),
(
i47,
i48,
),
(
i49,
i50,
),
(
i51,
i52,
),
(
i53,
i54,
),
(
i55,
i56,
),
(
i57,
i58,
),
(
i59,
i60,
),
(
i61,
i62,
),
(
i63,
i64,
),
(
i65,
i66,
),
(
i67,
i68,
),
(
i69,
i70,
),
(
i71,
i72,
),
(
i73,
i74,
),
(
i75,
i76,
),
(
i77,
i78,
),
(
i79,
i80,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
)
return_vals = []
return_vals.append(op.logits)
temp_past_key_values = op.past_key_values
for item in temp_past_key_values:
@@ -608,12 +48,11 @@ class SecondVicuna13B(torch.nn.Module):
return tuple(return_vals)
class SecondVicuna70B(torch.nn.Module):
class SecondVicuna(torch.nn.Module):
def __init__(
self,
model_path,
precision="fp32",
accumulates="fp32",
weight_group_size=128,
model_name="vicuna",
hf_auth_token: str = None,
@@ -625,21 +64,12 @@ class SecondVicuna70B(torch.nn.Module):
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
self.accumulates = (
torch.float32 if accumulates == "fp32" else torch.float16
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
dtype=self.accumulates,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
@@ -717,103 +147,9 @@ class SecondVicuna70B(torch.nn.Module):
i62,
i63,
i64,
i65,
i66,
i67,
i68,
i69,
i70,
i71,
i72,
i73,
i74,
i75,
i76,
i77,
i78,
i79,
i80,
i81,
i82,
i83,
i84,
i85,
i86,
i87,
i88,
i89,
i90,
i91,
i92,
i93,
i94,
i95,
i96,
i97,
i98,
i99,
i100,
i101,
i102,
i103,
i104,
i105,
i106,
i107,
i108,
i109,
i110,
i111,
i112,
i113,
i114,
i115,
i116,
i117,
i118,
i119,
i120,
i121,
i122,
i123,
i124,
i125,
i126,
i127,
i128,
i129,
i130,
i131,
i132,
i133,
i134,
i135,
i136,
i137,
i138,
i139,
i140,
i141,
i142,
i143,
i144,
i145,
i146,
i147,
i148,
i149,
i150,
i151,
i152,
i153,
i154,
i155,
i156,
i157,
i158,
i159,
i160,
):
# input_ids = input_tuple[0]
# input_tuple = torch.unbind(pkv, dim=0)
token = i0
past_key_values = (
(i1, i2),
@@ -941,198 +277,6 @@ class SecondVicuna70B(torch.nn.Module):
i63,
i64,
),
(
i65,
i66,
),
(
i67,
i68,
),
(
i69,
i70,
),
(
i71,
i72,
),
(
i73,
i74,
),
(
i75,
i76,
),
(
i77,
i78,
),
(
i79,
i80,
),
(
i81,
i82,
),
(
i83,
i84,
),
(
i85,
i86,
),
(
i87,
i88,
),
(
i89,
i90,
),
(
i91,
i92,
),
(
i93,
i94,
),
(
i95,
i96,
),
(
i97,
i98,
),
(
i99,
i100,
),
(
i101,
i102,
),
(
i103,
i104,
),
(
i105,
i106,
),
(
i107,
i108,
),
(
i109,
i110,
),
(
i111,
i112,
),
(
i113,
i114,
),
(
i115,
i116,
),
(
i117,
i118,
),
(
i119,
i120,
),
(
i121,
i122,
),
(
i123,
i124,
),
(
i125,
i126,
),
(
i127,
i128,
),
(
i129,
i130,
),
(
i131,
i132,
),
(
i133,
i134,
),
(
i135,
i136,
),
(
i137,
i138,
),
(
i139,
i140,
),
(
i141,
i142,
),
(
i143,
i144,
),
(
i145,
i146,
),
(
i147,
i148,
),
(
i149,
i150,
),
(
i151,
i152,
),
(
i153,
i154,
),
(
i155,
i156,
),
(
i157,
i158,
),
(
i159,
i160,
),
)
op = self.model(
input_ids=token, use_cache=True, past_key_values=past_key_values
@@ -1154,17 +298,15 @@ class CombinedModel(torch.nn.Module):
):
super().__init__()
self.first_vicuna = FirstVicuna(first_vicuna_model_path)
# NOT using this path for 13B currently, hence using `SecondVicuna7B`.
self.second_vicuna = SecondVicuna7B(second_vicuna_model_path)
self.second_vicuna = SecondVicuna(second_vicuna_model_path)
def forward(self, input_ids):
first_output = self.first_vicuna(input_ids=input_ids)
# generate second vicuna
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
second_output = self.second_vicuna(*secondVicunaCompileInput)
first_output = self.first_vicuna(input_ids=input_ids, use_cache=True)
logits = first_output[0]
pkv = first_output[1:]
token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
token = token.to(torch.int64).reshape([1, 1])
secondVicunaInput = (token,) + tuple(pkv)
second_output = self.second_vicuna(secondVicunaInput)
return second_output

File diff suppressed because it is too large Load Diff

View File

@@ -66,7 +66,7 @@ class ShardedVicunaModel(torch.nn.Module):
def __init__(self, model, layers, lmhead, embedding, norm):
super().__init__()
self.model = model
# assert len(layers) == len(model.model.layers)
assert len(layers) == len(model.model.layers)
self.model.model.config.use_cache = True
self.model.model.config.output_attentions = False
self.layers = layers
@@ -132,10 +132,7 @@ class VicunaNormCompiled(torch.nn.Module):
self.model = shark_module
def forward(self, hidden_states):
try:
hidden_states.detach()
except:
pass
hidden_states.detach()
output = self.model("forward", (hidden_states,))
output = torch.tensor(output)
return output

View File

@@ -3,10 +3,7 @@ from abc import ABC, abstractmethod
class SharkLLMBase(ABC):
def __init__(
self,
model_name,
hf_model_path=None,
max_num_tokens=512,
self, model_name, hf_model_path=None, max_num_tokens=512
) -> None:
self.model_name = model_name
self.hf_model_path = hf_model_path

View File

@@ -7,9 +7,9 @@ from io import BytesIO
from pathlib import Path
from contextlib import redirect_stdout
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation import (
GenerationConfig,
LogitsProcessorList,
@@ -28,11 +28,9 @@ parser = argparse.ArgumentParser(
description="runs a falcon model",
)
parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b")
parser.add_argument(
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
"--precision", "-p", default="fp16", help="fp32, fp16, int8, int4"
)
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
parser.add_argument(
@@ -51,7 +49,7 @@ parser.add_argument(
)
parser.add_argument(
"--load_mlir_from_shark_tank",
default=True,
default=False,
action=argparse.BooleanOptionalAction,
help="download precompile mlir from shark tank",
)
@@ -61,52 +59,32 @@ parser.add_argument(
action=argparse.BooleanOptionalAction,
help="Run model in cli mode",
)
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="Specify your own huggingface authentication token for falcon-180B model.",
)
class Falcon(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="tiiuae/falcon-7b-instruct",
hf_auth_token: str = None,
hf_model_path,
max_num_tokens=150,
device="cuda",
precision="fp32",
falcon_mlir_path=None,
falcon_vmfb_path=None,
debug=False,
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
print("hf_model_path: ", self.hf_model_path)
if "180b" in self.model_name and hf_auth_token == None:
raise ValueError(
""" HF auth token required for falcon-180b. Pass it using
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
self.precision = precision
self.falcon_vmfb_path = falcon_vmfb_path
self.falcon_mlir_path = falcon_mlir_path
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.src_model = self.get_src_model()
self.shark_model = self.compile()
self.src_model = self.get_src_model()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.hf_model_path,
trust_remote_code=True,
token=self.hf_auth_token,
self.hf_model_path, trust_remote_code=True
)
tokenizer.padding_side = "left"
tokenizer.pad_token_id = 11
@@ -114,24 +92,13 @@ class Falcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
kwargs = {"torch_dtype": torch.float, "trust_remote_code": True}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile(self):
def compile_falcon(self):
if args.use_precompiled_model:
if not self.falcon_vmfb_path.exists():
# Downloading VMFB from shark_tank
@@ -153,37 +120,37 @@ class Falcon(SharkLLMBase):
if vmfb is not None:
return vmfb
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
print(
f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}. Trying to work with"
f"[DEBUG] mlir path { self.falcon_mlir_path} {'exists' if self.falcon_mlir_path.exists() else 'does not exist'}"
)
if self.falcon_mlir_path.exists():
print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}")
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
else:
mlir_generated = False
print(
f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}"
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
)
if args.load_mlir_from_shark_tank:
# Downloading MLIR from shark_tank
print(f"[DEBUG] Trying to download mlir from shark_tank")
download_public_file(
"gs://shark_tank/falcon/"
+ "falcon_"
+ args.falcon_variant_to_use
+ "_"
+ self.precision
+ ".mlir",
self.falcon_mlir_path.absolute(),
single_file=True,
if self.falcon_mlir_path.exists():
with open(self.falcon_mlir_path, "rb") as f:
bytecode = f.read()
mlir_generated = True
else:
raise ValueError(
f"MLIR not found at {self.falcon_mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
if self.falcon_mlir_path.exists():
print(
f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}"
)
mlir_generated = True
if not mlir_generated:
print(f"[DEBUG] generating MLIR locally")
compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 100)
)
@@ -200,10 +167,9 @@ class Falcon(SharkLLMBase):
ts_graph = import_with_fx(
model,
falconCompileInput,
is_f16=self.precision in ["fp16", "int4"],
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
is_gptq=self.precision == "int4",
)
del model
print(f"[DEBUG] generating torch mlir")
@@ -223,37 +189,35 @@ class Falcon(SharkLLMBase):
bytecode = bytecode_stream.getvalue()
del module
f_ = open(self.falcon_mlir_path, "wb")
f_.write(bytecode)
print("Saved falcon mlir at ", str(self.falcon_mlir_path))
print(f"[DEBUG] writing mlir to file")
with open(f"{self.model_name}.mlir", "wb") as f_:
with redirect_stdout(f_):
print(module.operation.get_asm())
f_.close()
del bytecode
shark_module = SharkInference(
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
mlir_module=bytecode, device=self.device, mlir_dialect="linalg"
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
self.falcon_vmfb_path.stem,
extra_args=[
"--iree-hal-dump-executable-sources-to=ies",
"--iree-vm-target-truncate-unsupported-floats",
"--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
+ [
"--iree-llvmcpu-use-fast-min-max-ops",
]
if self.precision == "int4"
else [],
debug=self.debug,
"--iree-spirv-index-bits=64",
],
)
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
def compile(self):
falcon_shark_model = self.compile_falcon()
return falcon_shark_model
def generate(self, prompt):
model_inputs = self.tokenizer(
prompt,
@@ -423,7 +387,7 @@ class Falcon(SharkLLMBase):
(model_inputs["input_ids"], model_inputs["attention_mask"]),
)
)
if self.precision in ["fp16", "int4"]:
if self.precision == "fp16":
outputs = outputs.to(dtype=torch.float32)
next_token_logits = outputs
@@ -502,26 +466,11 @@ if __name__ == "__main__":
else Path(args.falcon_vmfb_path)
)
if args.precision == "int4":
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
else:
hf_model_path_value = (
"TheBloke/falcon-"
+ args.falcon_variant_to_use
+ "-instruct-GPTQ"
)
else:
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "tiiuae/falcon-180B-chat"
else:
hf_model_path_value = (
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)
falcon = Falcon(
model_name="falcon_" + args.falcon_variant_to_use,
hf_model_path=hf_model_path_value,
"falcon_" + args.falcon_variant_to_use,
hf_model_path="tiiuae/falcon-"
+ args.falcon_variant_to_use
+ "-instruct",
device=args.device,
precision=args.precision,
falcon_mlir_path=falcon_mlir_path,
@@ -548,11 +497,7 @@ if __name__ == "__main__":
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
User: {prompt}
Assistant:"""
res_str = falcon.generate(prompt_template)
res_str = falcon.generate(prompt)
torch.cuda.empty_cache()
gc.collect()
print(

View File

@@ -126,7 +126,7 @@ def is_url(input_url):
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
import torch
import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -136,8 +136,7 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
def brevitasmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -146,21 +145,20 @@ def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_s
raise ValueError("Input shapes not supported.")
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
def brevitasmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
def brevitasmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics]
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
@@ -178,7 +176,7 @@ def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
def compile_module(
shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False,
shark_module, extended_model_name, generate_vmfb, extra_args=[]
):
if generate_vmfb:
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
@@ -190,7 +188,7 @@ def compile_module(
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
)
path = shark_module.save_module(
os.getcwd(), extended_model_name, extra_args, debug=debug
os.getcwd(), extended_model_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -199,7 +197,7 @@ def compile_module(
def compile_int_precision(
model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False
model, inputs, precision, device, generate_vmfb, extended_model_name
):
torchscript_module = import_with_fx(
model,
@@ -211,7 +209,7 @@ def compile_int_precision(
torchscript_module,
inputs,
output_type="torch",
backend_legal_ops=["quant.matmul_rhs_group_quant"],
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -219,7 +217,7 @@ def compile_int_precision(
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
mlir_module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
from contextlib import redirect_stdout
@@ -235,12 +233,6 @@ def compile_int_precision(
mlir_module = BytesIO(mlir_module)
bytecode = mlir_module.read()
print(f"Elided IR written for {extended_model_name}")
bytecode = save_mlir(
bytecode,
model_name=extended_model_name,
frontend="torch",
dir=os.getcwd(),
)
return bytecode
shark_module = SharkInference(
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
@@ -257,7 +249,6 @@ def compile_int_precision(
extended_model_name=extended_model_name,
generate_vmfb=generate_vmfb,
extra_args=extra_args,
debug=debug,
),
bytecode,
)
@@ -301,7 +292,6 @@ def shark_compile_through_fx_int(
device,
generate_or_load_vmfb,
extended_model_name,
debug,
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",

View File

@@ -32,13 +32,11 @@ class SharkStableLM(SharkLLMBase):
max_num_tokens=512,
device="cuda",
precision="fp32",
debug="False",
) -> None:
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_len = 256
self.device = device
self.precision = precision
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.shark_model = self.compile()
@@ -113,7 +111,7 @@ class SharkStableLM(SharkLLMBase):
shark_module.compile()
path = shark_module.save_module(
vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug
vmfb_path.parent.absolute(), vmfb_path.stem
)
print("Saved vmfb at ", str(path))

View File

@@ -8,7 +8,7 @@ from shark.shark_downloader import download_public_file
# expects a Path / str as arg
# returns None if path not found or SharkInference module
def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
if not isinstance(vmfb_path, Path):
vmfb_path = Path(vmfb_path)
@@ -20,7 +20,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
print("Loading vmfb from: ", vmfb_path)
print("Device from get_vmfb_from_path - ", device)
shark_module = SharkInference(
None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id
None, device=device, mlir_dialect=mlir_dialect
)
shark_module.load_module(vmfb_path)
print("Successfully loaded vmfb")
@@ -28,13 +28,7 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None):
def get_vmfb_from_config(
shark_container,
model,
precision,
device,
vmfb_path,
padding=None,
device_id=None,
shark_container, model, precision, device, vmfb_path, padding=None
):
vmfb_url = (
f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
@@ -43,6 +37,4 @@ def get_vmfb_from_config(
vmfb_url = vmfb_url + f"_{padding}"
vmfb_url = vmfb_url + ".vmfb"
download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
return get_vmfb_from_path(
vmfb_path, device, "tm_tensor", device_id=device_id
)
return get_vmfb_from_path(vmfb_path, device, "tm_tensor")

View File

@@ -7,16 +7,16 @@ Compile Commands FP32/FP16:
```shell
Vulkan AMD:
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug
# use iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models
CUDA NVIDIA:
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=cuda --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
CPU:
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 /path/to/input/mlir -o /path/to/output/vmfb
```

View File

@@ -34,7 +34,7 @@ from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
from diffusers.models.cross_attention import LoRACrossAttnProcessor
import torch_mlir
from torch_mlir.dynamo import make_simple_dynamo_backend
@@ -287,7 +287,7 @@ def lora_train(
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)

View File

@@ -15,8 +15,8 @@ pathex = [
# datafiles for pyinstaller
datas = []
datas += collect_data_files("torch")
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
@@ -30,29 +30,26 @@ datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("tiktoken")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("pytorch_lightning")
datas += collect_data_files("opencv_python")
datas += collect_data_files("skimage")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree")
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("google_cloud_storage")
datas += collect_data_files("shark")
datas += collect_data_files("timm", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("webview")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),
@@ -75,13 +72,6 @@ datas += [
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
blacklist = ["tests", "convert"]
hiddenimports += [
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
x for x in collect_submodules("transformers") if "tests" not in x
]
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]

View File

@@ -177,11 +177,9 @@ class SharkifyStableDiffusionModel:
"unet",
"unet512",
"stencil_unet",
"stencil_unet_512",
"vae",
"vae_encode",
"stencil_adaptor",
"stencil_adaptor_512",
]
index = 0
for model in sub_model_list:
@@ -341,7 +339,7 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_controlled_unet(self, use_large=False):
def get_controlled_unet(self):
class ControlledUnetModel(torch.nn.Module):
def __init__(
self,
@@ -417,16 +415,6 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
model_name = "stencil_unet"
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[:2]
+ (torch.nn.functional.pad(inputs[2], pad),)
+ inputs[3:]
)
model_name = "stencil_unet_512"
input_mask = [
True,
True,
@@ -449,19 +437,19 @@ class SharkifyStableDiffusionModel:
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name[model_name],
extended_model_name=self.model_name["stencil_unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name=model_name,
model_name="stencil_unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self, use_large=False):
def get_control_net(self):
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
@@ -509,34 +497,17 @@ class SharkifyStableDiffusionModel:
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["stencil_adaptor"])
if use_large:
pad = (0, 0) * (len(inputs[2].shape) - 2)
pad = pad + (0, 512 - inputs[2].shape[1])
inputs = (
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
)
else:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
inputs,
extended_model_name=self.model_name[model_name],
extended_model_name=self.model_name["stencil_adaptor"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name=model_name,
model_name="stencil_adaptor",
precision=self.precision,
return_mlir=self.return_mlir,
)
@@ -710,11 +681,8 @@ class SharkifyStableDiffusionModel:
return self.text_encoder(input)[0]
clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage)
save_dir = ""
save_dir = os.path.join(self.sharktank_dir, self.model_name["clip"])
if self.debug:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["clip"]
)
os.makedirs(
save_dir,
exist_ok=True,
@@ -780,7 +748,7 @@ class SharkifyStableDiffusionModel:
else:
return self.get_unet(use_large=use_large)
else:
return self.get_controlled_unet(use_large=use_large)
return self.get_controlled_unet()
def vae_encode(self):
try:
@@ -879,14 +847,12 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def controlnet(self, use_large=False):
def controlnet(self):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
check_compilation(compiled_stencil_adaptor, "Stencil")
if self.return_mlir:

View File

@@ -84,35 +84,13 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps,
strength,
dtype,
resample_type,
):
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre-process image
if resample_type == "Lanczos":
resample_type = Image.LANCZOS
elif resample_type == "Nearest Neighbor":
resample_type = Image.NEAREST
elif resample_type == "Bilinear":
resample_type = Image.BILINEAR
elif resample_type == "Bicubic":
resample_type = Image.BICUBIC
elif resample_type == "Adaptive":
resample_type = Image.ADAPTIVE
elif resample_type == "Antialias":
resample_type = Image.ANTIALIAS
elif resample_type == "Box":
resample_type = Image.BOX
elif resample_type == "Affine":
resample_type = Image.AFFINE
elif resample_type == "Cubic":
resample_type = Image.CUBIC
else: # Fallback to Lanczos
resample_type = Image.LANCZOS
image = image.resize((width, height), resample=resample_type)
# Pre process image
image = image.resize((width, height))
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
@@ -169,7 +147,6 @@ class Image2ImagePipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -209,7 +186,6 @@ class Image2ImagePipeline(StableDiffusionPipeline):
num_inference_steps=num_inference_steps,
strength=strength,
dtype=dtype,
resample_type=resample_type,
)
# Get Image latents

View File

@@ -58,7 +58,6 @@ class StencilPipeline(StableDiffusionPipeline):
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
def load_controlnet(self):
if self.controlnet is not None:
@@ -69,15 +68,6 @@ class StencilPipeline(StableDiffusionPipeline):
del self.controlnet
self.controlnet = None
def load_controlnet_512(self):
if self.controlnet_512 is not None:
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def prepare_latents(
self,
batch_size,
@@ -121,12 +111,8 @@ class StencilPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
self.load_unet()
self.load_controlnet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -149,82 +135,43 @@ class StencilPipeline(StableDiffusionPipeline):
).to(dtype)
else:
latent_model_input_1 = latent_model_input
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
else:
print(self.unet_512)
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
guidance_scale,
control[0],
control[1],
control[2],
control[3],
control[4],
control[5],
control[6],
control[7],
control[8],
control[9],
control[10],
control[11],
control[12],
),
send_to_host=False,
)
end_profiling(profile_device)
if cpu_scheduling:
@@ -244,9 +191,7 @@ class StencilPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -273,7 +218,6 @@ class StencilPipeline(StableDiffusionPipeline):
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.

View File

@@ -84,6 +84,9 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
def _import(self):
scaling_model = ScalingModel()

View File

@@ -109,7 +109,7 @@ def load_lower_configs(base_model_id=None):
spec = spec.split("-")[0]
if args.annotation_model == "vae":
if not spec or spec in ["sm_80"]:
if not spec or spec in ["rdna3", "sm_80"]:
config_name = (
f"{args.annotation_model}_{args.precision}_{device}.json"
)
@@ -158,9 +158,9 @@ def load_lower_configs(base_model_id=None):
f"{spec}.json"
)
full_gs_url = config_bucket + config_name
lowering_config_dir = os.path.join(WORKDIR, "configs", config_name)
print("Loading lowering config file from ", lowering_config_dir)
full_gs_url = config_bucket + config_name
download_public_file(full_gs_url, lowering_config_dir, True)
return lowering_config_dir

View File

@@ -132,57 +132,6 @@ p.add_argument(
"img2img.",
)
p.add_argument(
"--use_hiresfix",
type=bool,
default=False,
help="Use Hires Fix to do higher resolution images, while trying to "
"avoid the issues that come with it. This is accomplished by first "
"generating an image using txt2img, then running it through img2img.",
)
p.add_argument(
"--hiresfix_height",
type=int,
default=768,
choices=range(128, 769, 8),
help="The height of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_width",
type=int,
default=768,
choices=range(128, 769, 8),
help="The width of the Hires Fix image.",
)
p.add_argument(
"--hiresfix_strength",
type=float,
default=0.6,
help="The denoising strength to apply for the Hires Fix.",
)
p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
##############################################################################
# Stable Diffusion Training Params
##############################################################################
@@ -458,14 +407,6 @@ p.add_argument(
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--device_allocator_heap_key",
type=str,
default="",
help="Specify heap key for device caching allocator."
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################
@@ -578,20 +519,6 @@ p.add_argument(
"in shark importer. Does nothing if import_mlir is false (the default).",
)
p.add_argument(
"--compile_debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag to toggle debug assert/verify flags for imported IR in the"
"iree-compiler. Default to false.",
)
p.add_argument(
"--iree_constant_folding",
default=True,
action=argparse.BooleanOptionalAction,
help="Controls constant folding in iree-compile for all SD models.",
)
##############################################################################
# Web UI flags
@@ -641,13 +568,6 @@ p.add_argument(
help="Flag for enabling rest API.",
)
p.add_argument(
"--debug",
default=False,
action=argparse.BooleanOptionalAction,
help="Flag for enabling debugging log in WebUI.",
)
p.add_argument(
"--output_gallery",
default=True,

View File

@@ -18,14 +18,14 @@ import tempfile
import torch
from safetensors.torch import load_file
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
set_iree_vulkan_runtime_flags,
get_vulkan_target_triple,
get_iree_vulkan_runtime_flags,
)
from shark.iree_utils.metal_utils import get_metal_target_triple
from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
@@ -78,7 +78,7 @@ def _compile_module(shark_module, model_name, extra_args=[]):
)
)
path = shark_module.save_module(
os.getcwd(), model_name, extra_args, debug=args.compile_debug
os.getcwd(), model_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
@@ -154,8 +154,8 @@ def compile_through_fx(
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
if use_tuned:
if "vae" in extended_model_name.split("_")[0]:
args.annotation_model = "vae"
@@ -168,14 +168,6 @@ def compile_through_fx(
mlir_module, extended_model_name, base_model_id
)
if not os.path.isdir(save_dir):
save_dir = ""
mlir_module = save_mlir(
mlir_module,
model_name=extended_model_name,
dir=save_dir,
)
shark_module = SharkInference(
mlir_module,
device=args.device if device is None else device,
@@ -187,22 +179,17 @@ def compile_through_fx(
mlir_module,
)
del mlir_module
gc.collect()
def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if args.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
@@ -483,25 +470,12 @@ def get_available_devices():
set_iree_runtime_flags()
available_devices = []
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
)
vulkaninfo_list = get_all_vulkan_devices()
vulkan_devices = []
id = 0
for device in vulkaninfo_list:
vulkan_devices.append(f"{device.strip()} => vulkan://{id}")
id += 1
if id != 0:
print(f"vulkan devices are available.")
vulkan_devices = get_devices_by_name("vulkan")
available_devices.extend(vulkan_devices)
metal_devices = get_devices_by_name("metal")
available_devices.extend(metal_devices)
cuda_devices = get_devices_by_name("cuda")
available_devices.extend(cuda_devices)
rocm_devices = get_devices_by_name("rocm")
available_devices.extend(rocm_devices)
cpu_device = get_devices_by_name("cpu-sync")
available_devices.extend(cpu_device)
cpu_device = get_devices_by_name("cpu-task")
@@ -525,15 +499,10 @@ def get_opt_flags(model, precision="fp16"):
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
if "rocm" in args.device:
rocm_args = get_iree_rocm_args()
iree_flags.extend(rocm_args)
print(iree_flags)
if args.iree_constant_folding == False:
iree_flags.append("--iree-opt-const-expr-hoisting=False")
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
@@ -597,7 +566,7 @@ def preprocessCKPT(custom_weights, is_inpaint=False):
)
num_in_channels = 9 if is_inpaint else 4
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=custom_weights,
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
num_in_channels=num_in_channels,
@@ -847,8 +816,6 @@ def clear_all():
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
if args.local_tank_cache != "":
shutil.rmtree(args.local_tank_cache)
def get_generated_imgs_path() -> Path:

View File

@@ -1,7 +1,6 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
if sys.platform == "darwin":
# import before IREE to avoid torch-MLIR library issues
@@ -38,12 +37,10 @@ def launch_app(address):
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
webview.start(private_mode=False)
if __name__ == "__main__":
if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
if args.api or "api" in args.ui.split(","):
@@ -118,8 +115,7 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
# h2ogpt_upload,
# h2ogpt_web,
h2ogpt_web,
img2img_web,
img2img_custom_model,
img2img_hf_model_id,
@@ -156,9 +152,8 @@ if __name__ == "__main__":
upscaler_sendto_img2img,
upscaler_sendto_inpaint,
upscaler_sendto_outpaint,
# lora_train_web,
# model_web,
# model_config_web,
lora_train_web,
model_web,
hf_models,
modelmanager_sendto_txt2img,
modelmanager_sendto_img2img,
@@ -216,15 +211,6 @@ if __name__ == "__main__":
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
with gr.TabItem(label="Text-to-Image", id=0):
txt2img_web.render()
with gr.TabItem(label="Image-to-Image", id=1):
@@ -250,22 +236,16 @@ if __name__ == "__main__":
upscaler_status,
]
)
# with gr.TabItem(label="Model Manager", id=6):
# model_web.render()
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
# lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=8):
with gr.TabItem(label="Model Manager", id=6):
model_web.render()
with gr.TabItem(label="LoRA Training (Experimental)", id=8):
lora_train_web.render()
with gr.TabItem(label="Chat Bot (Experimental)", id=7):
stablelm_chat.render()
# with gr.TabItem(
# label="Generate Sharding Config (Experimental)", id=9
# ):
# model_config_web.render()
with gr.TabItem(label="MultiModal (Experimental)", id=10):
with gr.TabItem(label="MultiModal (Experimental)", id=9):
minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
with gr.TabItem(label="DocuChat(Experimental)", id=10):
h2ogpt_web.render()
# send to buttons
register_button_click(

View File

@@ -78,7 +78,7 @@ from apps.stable_diffusion.web.ui.stablelm_ui import (
stablelm_chat,
llm_chat_api,
)
from apps.stable_diffusion.web.ui.generate_config import model_config_web
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
from apps.stable_diffusion.web.ui.outputgallery_ui import (
outputgallery_web,

View File

@@ -1,41 +0,0 @@
import gradio as gr
import torch
from transformers import AutoTokenizer
from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel
from shark.shark_generate_model_config import GenerateConfigFile
def get_model_config():
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
return c.split_into_layers()
with gr.Blocks() as model_config_web:
with gr.Row():
hf_models = gr.Dropdown(
label="Model List",
choices=["Vicuna"],
value="Vicuna",
visible=True,
)
get_model_config_btn = gr.Button(value="Get Model Config")
json_view = gr.JSON()
get_model_config_btn.click(
fn=get_model_config,
inputs=[],
outputs=[json_view],
)

View File

@@ -12,10 +12,6 @@ from apps.language_models.langchain.enums import (
LangChainAction,
)
import apps.language_models.langchain.gen as gen
from gpt_langchain import (
path_to_docs,
create_or_update_db,
)
from apps.stable_diffusion.src import args
@@ -37,15 +33,8 @@ start_message = """
def create_prompt(history):
system_message = start_message
for item in history:
print("His item: ", item)
conversation = "<|endoftext|>".join(
[
"<|endoftext|><|answer|>".join([item[0], item[1]])
for item in history
]
)
conversation = "".join(["".join([item[0], item[1]]) for item in history])
msg = system_message + conversation
msg = msg.strip()
@@ -55,12 +44,10 @@ def create_prompt(history):
def chat(curr_system_message, history, device, precision):
args.run_docuchat_web = True
global h2ogpt_model
global sharkModel
global h2ogpt_tokenizer
global model_state
global langchain
global userpath_selector
from apps.language_models.langchain.h2oai_pipeline import generate_token
if h2ogpt_model == 0:
if "cuda" in device:
@@ -115,14 +102,9 @@ def chat(curr_system_message, history, device, precision):
prompt_type=None,
prompt_dict=None,
)
from apps.language_models.langchain.h2oai_pipeline import (
H2OGPTSHARKModel,
)
sharkModel = H2OGPTSHARKModel()
prompt = create_prompt(history)
output_dict = langchain.evaluate(
output = langchain.evaluate(
model_state=model_state,
my_db_state=None,
instruction=prompt,
@@ -182,22 +164,14 @@ def chat(curr_system_message, history, device, precision):
model_lock=True,
user_path=userpath_selector.value,
)
output = generate_token(sharkModel, **output_dict)
for partial_text in output:
history[-1][1] = partial_text
history[-1][1] = partial_text["response"]
yield history
return history
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(os.path.abspath("apps/language_models/langchain/user_path/")),
interactive=True,
container=True,
)
with gr.Blocks(title="DocuChat") as h2ogpt_web:
with gr.Blocks(title="H2OGPT") as h2ogpt_web:
with gr.Row():
supported_devices = available_devices
enabled = len(supported_devices) > 0
@@ -212,7 +186,6 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
@@ -225,6 +198,14 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
],
visible=True,
)
userpath_selector = gr.Textbox(
label="Document Directory",
value=str(
os.path.abspath("apps/language_models/langchain/user_path/")
),
interactive=True,
container=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
@@ -268,100 +249,3 @@ with gr.Blocks(title="DocuChat") as h2ogpt_web:
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)
with gr.Blocks(title="DocuChat Upload") as h2ogpt_upload:
import pathlib
upload_path = None
database = None
database_directory = os.path.abspath(
"apps/language_models/langchain/db_path/"
)
def read_path():
global upload_path
filenames = [
[f]
for f in os.listdir(upload_path)
if os.path.isfile(os.path.join(upload_path, f))
]
filenames.sort()
return filenames
def upload_file(f):
names = []
for tmpfile in f:
name = tmpfile.name.split("/")[-1]
basename = os.path.join(upload_path, name)
with open(basename, "wb") as w:
with open(tmpfile.name, "rb") as r:
w.write(r.read())
update_or_create_db()
return read_path()
def update_userpath(newpath):
global upload_path
upload_path = newpath
pathlib.Path(upload_path).mkdir(parents=True, exist_ok=True)
return read_path()
def update_or_create_db():
global database
global upload_path
sources = path_to_docs(
upload_path,
verbose=True,
fail_any_exception=False,
n_jobs=-1,
chunk=True,
chunk_size=512,
url=None,
enable_captions=False,
captions_model=None,
caption_loader=None,
enable_ocr=False,
)
pathlib.Path(database_directory).mkdir(parents=True, exist_ok=True)
database = create_or_update_db(
"chroma",
database_directory,
"UserData",
sources,
False,
True,
True,
"sentence-transformers/all-MiniLM-L6-v2",
)
def first_run():
global database
if database is None:
update_or_create_db()
update_userpath(
os.path.abspath("apps/language_models/langchain/user_path/")
)
h2ogpt_upload.load(fn=first_run)
h2ogpt_web.load(fn=first_run)
with gr.Column():
text = gr.DataFrame(
col_count=(1, "fixed"),
type="array",
label="Documents",
value=read_path(),
)
with gr.Row():
upload = gr.UploadButton(
label="Upload documents",
file_count="multiple",
)
upload.upload(fn=upload_file, inputs=upload, outputs=text)
userpath_selector.render()
userpath_selector.input(
fn=update_userpath, inputs=userpath_selector, outputs=text
).then(fn=update_or_create_db)

View File

@@ -3,7 +3,6 @@ import torch
import time
import gradio as gr
import PIL
from math import ceil
from PIL import Image
import base64
from io import BytesIO
@@ -68,7 +67,6 @@ def img2img_inf(
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -247,7 +245,7 @@ def img2img_inf(
batch_size,
height,
width,
ceil(steps / strength),
steps,
strength,
guidance_scale,
seeds[current_batch],
@@ -257,7 +255,6 @@ def img2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
resample_type=resample_type,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
@@ -351,7 +348,6 @@ def img2img_api(
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
resample_type="Lanczos",
)
# Converts generator type to subscriptable
@@ -396,7 +392,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
)
img2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -422,7 +417,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -438,7 +432,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lines=2,
elem_id="negative_prompt_box",
)
# TODO: make this import image prompt info if it exists
img2img_init_image = gr.Image(
label="Input Image",
source="upload",
@@ -454,7 +448,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Stencil model",
value="None",
choices=["None", "canny", "openpose", "scribble"],
allow_custom_value=True,
)
def show_canvas(choice):
@@ -515,7 +508,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
elem_id="lora_weights",
@@ -539,7 +531,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -559,6 +550,15 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
width = gr.Slider(
384, 768, value=args.width, step=8, label="Width"
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
@@ -581,36 +581,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
label="Resample Type",
allow_custom_value=True,
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=args.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
with gr.Row():
with gr.Column(scale=3):
guidance_scale = gr.Slider(
@@ -654,7 +629,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
@@ -721,7 +695,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
lora_hf_id,
ondemand,
repeatable_seeds,
resample_type,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -344,7 +344,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
allow_custom_value=True,
)
inpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -370,7 +369,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -408,7 +406,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -427,7 +424,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -531,7 +527,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")

View File

@@ -50,7 +50,6 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
)
hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -74,7 +73,6 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -107,7 +105,6 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Row():
height = gr.Slider(
@@ -180,7 +177,6 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
with gr.Column(scale=2):

View File

@@ -109,7 +109,7 @@ with gr.Blocks() as minigpt4_web:
gr.Markdown(description)
with gr.Row():
with gr.Column():
with gr.Column(scale=0.5):
image = gr.Image(type="pil")
upload_button = gr.Button(
value="Upload & Start Chat",
@@ -143,7 +143,6 @@ with gr.Blocks() as minigpt4_web:
# else "Only CUDA Supported for now",
choices=["cuda"],
interactive=False,
allow_custom_value=True,
)
with gr.Column():

View File

@@ -98,7 +98,6 @@ with gr.Blocks() as model_web:
choices=None,
value=None,
visible=False,
allow_custom_value=True,
)
# TODO: select and SendTo
civit_models = gr.Gallery(

View File

@@ -351,7 +351,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
custom_checkpoint_type="inpainting"
)
+ predefined_paint_models,
allow_custom_value=True,
)
outpaint_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -377,7 +376,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -413,7 +411,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -432,7 +429,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -559,7 +555,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")

View File

@@ -109,7 +109,6 @@ with gr.Blocks() as outputgallery_web:
value="",
interactive=True,
elem_classes="dropdown_no_container",
allow_custom_value=True,
)
with gr.Column(
scale=1,

View File

@@ -7,8 +7,6 @@ from transformers import (
)
from apps.stable_diffusion.web.ui.utils import available_devices
from datetime import datetime as dt
import json
import sys
def user(message, history):
@@ -24,9 +22,11 @@ past_key_values = None
model_map = {
"llama2_7b": "meta-llama/Llama-2-7b-chat-hf",
"llama2_13b": "meta-llama/Llama-2-13b-chat-hf",
"llama2_70b": "meta-llama/Llama-2-70b-chat-hf",
"codegen": "Salesforce/codegen25-7b-multi",
"vicuna1p3": "lmsys/vicuna-7b-v1.3",
"vicuna": "TheBloke/vicuna-7B-1.1-HF",
"StableLM": "stabilityai/stablelm-tuned-alpha-3b",
}
# NOTE: Each `model_name` should have its own start message
@@ -40,15 +40,6 @@ start_message = {
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_13b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
@@ -58,41 +49,54 @@ start_message = {
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
),
"StableLM": (
"<|SYSTEM|># StableLM Tuned (Alpha version)"
"\n- StableLM is a helpful and harmless open-source AI language model "
"developed by StabilityAI."
"\n- StableLM is excited to be able to help the user, but will refuse "
"to do anything that could be considered harmful to the user."
"\n- StableLM is more than just an information source, StableLM is also "
"able to write poetry, short stories, and make jokes."
"\n- StableLM will refuse to participate in anything that "
"could harm a human."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"vicuna1p3": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
),
"codegen": "",
}
def create_prompt(model_name, history, prompt_prefix):
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]
def create_prompt(model_name, history):
system_message = start_message[model_name]
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
if model_name in [
"StableLM",
"vicuna",
"vicuna1p3",
"llama2_7b",
"llama2_70b",
]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
msg = system_message + conversation
msg = msg.strip()
return msg
@@ -101,178 +105,84 @@ def set_vicuna_model(model):
vicuna_model = model
def get_default_config():
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()
model_vmfb_key = ""
# TODO: Make chat reusable for UI and API
def chat(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
def chat(curr_system_message, history, model, device, precision, cli=True):
global past_key_values
global model_vmfb_key
global vicuna_model
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
if model_name in [
"vicuna",
"vicuna1p3",
"codegen",
"llama2_7b",
"llama2_70b",
]:
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
from apps.stable_diffusion.src import args
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device = "vulkan"
else:
print("unrecognized device")
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use target triple : {vulkan_target_triple}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
max_toks = 128 if model_name == "codegen" else 512
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
prompt = create_prompt(model_name, history)
if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")
for partial_text in vicuna_model.generate(prompt, cli=cli):
history[-1][1] = partial_text
yield history
prompt = create_prompt(model_name, history, prompt_prefix)
return history
# else Model is StableLM
global sharkModel
from apps.language_models.src.pipelines.stablelm_pipeline import (
SharkStableLM,
)
if sharkModel == 0:
# max_new_tokens=512
shark_slm = SharkStableLM(
model_name
) # pass elements from UI as required
# Construct the input message string for the model by concatenating the
# current system message and conversation history
if len(curr_system_message.split()) > 160:
print("clearing context")
prompt = create_prompt(model_name, history)
generate_kwargs = dict(prompt=prompt)
words_list = shark_slm.generate(**generate_kwargs)
partial_text = ""
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
return history, ""
for new_text in words_list:
print(new_text)
partial_text += new_text
history[-1][1] = partial_text
# Yield an empty string to clean up the message textbox and the updated
# conversation history
yield history
return words_list
def llm_chat_api(InputData: dict):
@@ -308,7 +218,6 @@ def llm_chat_api(InputData: dict):
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
@@ -317,7 +226,6 @@ def llm_chat_api(InputData: dict):
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
@@ -328,9 +236,6 @@ def llm_chat_api(InputData: dict):
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
@@ -395,13 +300,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = available_devices
enabled = len(supported_devices) > 0
# show cpu-task device first in list for chatbot
supported_devices = supported_devices[-1:] + supported_devices[:-1]
supported_devices = [x for x in supported_devices if "sync" not in x]
print(supported_devices)
device = gr.Dropdown(
label="Device",
value=supported_devices[0]
@@ -409,39 +314,23 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
else "Only CUDA Supported for now",
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
# multiselect=True,
)
precision = gr.Radio(
label="Precision",
value="int4",
value="fp16",
choices=[
"int4",
"int8",
"fp16",
"fp32",
],
visible=False,
visible=True,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)
with gr.Row(visible=False):
with gr.Row():
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
config_file = gr.File(label="Upload sharding configuration")
json_view_button = gr.Button("View as JSON")
json_view = gr.JSON()
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
@@ -460,47 +349,24 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
system_msg = gr.Textbox(
start_message, label="System Message", interactive=False, visible=False
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
queue=True,
)
submit_click_event = submit.click(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
).then(
fn=chat,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
inputs=[system_msg, chatbot, model, device, precision],
outputs=[chatbot],
queue=True,
)
stop.click(

View File

@@ -4,7 +4,6 @@ import time
import sys
import gradio as gr
from PIL import Image
from math import ceil
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
@@ -27,7 +26,6 @@ from apps.stable_diffusion.src import (
utils,
save_output_img,
prompt_examples,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
@@ -64,11 +62,6 @@ def txt2img_inf(
lora_hf_id: str,
ondemand: bool,
repeatable_seeds: bool,
use_hiresfix: bool,
hiresfix_height: int,
hiresfix_width: int,
hiresfix_strength: float,
resample_type: str,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -207,81 +200,6 @@ def txt2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
)
# TODO: allow user to save original image
# TODO: add option to let user keep both pipelines loaded, and unload
# either at will
# TODO: add custom step value slider
# TODO: add option to use secondary model for the img2img pass
if use_hiresfix is True:
new_config_obj = Config(
"img2img",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
1,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil="None",
ondemand=ondemand,
)
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
global_obj.set_sd_obj(
Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
args.precision,
args.max_length,
1,
hiresfix_height,
hiresfix_width,
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
)
)
global_obj.set_sd_scheduler(args.scheduler)
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
out_imgs[0],
batch_size,
hiresfix_height,
hiresfix_width,
ceil(steps / hiresfix_strength),
hiresfix_strength,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil="None",
resample_type=resample_type,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
@@ -353,11 +271,6 @@ def txt2img_api(
lora_hf_id="",
ondemand=False,
repeatable_seeds=False,
use_hiresfix=False,
hiresfix_height=512,
hiresfix_width=512,
hiresfix_strength=0.6,
resample_type="Nearest Neighbor",
)
# Convert Generator to Subscriptable
@@ -406,7 +319,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
choices=["None"]
+ get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
)
txt2img_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -431,7 +343,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
else "None",
choices=["None"]
+ get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Column(scale=1, min_width=170):
txt2img_png_info_img = gr.Image(
@@ -468,7 +379,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -487,7 +397,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Scheduler",
value=args.scheduler,
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
@@ -551,50 +460,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Low VRAM",
interactive=True,
)
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
label="Resample Type",
allow_custom_value=True,
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
@@ -629,8 +494,17 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
@@ -656,18 +530,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
show_label=False,
)
txt2img_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img")
txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
@@ -703,11 +565,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
lora_hf_id,
ondemand,
repeatable_seeds,
use_hiresfix,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
resample_type,
],
outputs=[txt2img_gallery, std_output, txt2img_status],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -365,7 +365,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
custom_checkpoint_type="upscaler"
)
+ predefined_upscaler_models,
allow_custom_value=True,
)
upscaler_hf_model_id = gr.Textbox(
elem_id="hf_model_id",
@@ -391,7 +390,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
if args.custom_vae
else "None",
choices=["None"] + get_custom_model_files("vae"),
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -427,7 +425,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
@@ -446,7 +443,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Scheduler",
value="DDIM",
choices=scheduler_list_cpu_only,
allow_custom_value=True,
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
@@ -551,7 +547,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")

View File

@@ -25,7 +25,7 @@ class Config:
device: str
use_lora: str
use_stencil: str
ondemand: str # should this be expecting a bool instead?
ondemand: str
custom_model_filetypes = (

View File

@@ -24,13 +24,13 @@ def get_image(url, local_filename):
shutil.copyfileobj(res.raw, f)
def compare_images(new_filename, golden_filename, upload=False):
def compare_images(new_filename, golden_filename):
new = np.array(Image.open(new_filename)) / 255.0
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.1:
if os.name != "nt" and upload == True:
if os.name != "nt":
subprocess.run(
[
"gsutil",
@@ -39,7 +39,7 @@ def compare_images(new_filename, golden_filename, upload=False):
"gs://shark_tank/testdata/builder/",
]
)
raise AssertionError("new and golden not close")
raise SystemExit("new and golden not close")
else:
print("SUCCESS")

View File

@@ -1,6 +1,5 @@
#!/bin/bash
IMPORTER=1 BENCHMARK=1 NO_BREVITAS=1 ./setup_venv.sh
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python build_tools/stable_diffusion_testing.py --gen
python tank/generate_sharktank.py

View File

@@ -63,14 +63,7 @@ def get_inpaint_inputs():
open("./test_images/inputs/mask.png", "wb").write(mask.content)
def test_loop(
device="vulkan",
beta=False,
extra_flags=[],
upload_bool=True,
exit_on_fail=True,
do_gen=False,
):
def test_loop(device="vulkan", beta=False, extra_flags=[]):
# Get golden values from tank
shutil.rmtree("./test_images", ignore_errors=True)
model_metrics = []
@@ -88,8 +81,6 @@ def test_loop(
if beta:
extra_flags.append("--beta_models=True")
extra_flags.append("--no-progress_bar")
if do_gen:
extra_flags.append("--import_debug")
to_skip = [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
@@ -190,14 +181,7 @@ def test_loop(
"./test_images/golden/" + model_name + "/*.png"
)
golden_file = glob(golden_path)[0]
try:
compare_images(
test_file, golden_file, upload=upload_bool
)
except AssertionError as e:
print(e)
if exit_on_fail == True:
raise
compare_images(test_file, golden_file)
else:
print(command)
print("failed to generate image for this configuration")
@@ -216,9 +200,6 @@ def test_loop(
extra_flags.remove(
"--iree_vulkan_target_triple=rdna2-unknown-windows"
)
if do_gen:
prepare_artifacts()
with open(os.path.join(os.getcwd(), "sd_testing_metrics.csv"), "w+") as f:
header = "model_name;device;use_tune;import_opt;Clip Inference time(ms);Average Step (ms/it);VAE Inference time(ms);total image generation(s);command\n"
f.write(header)
@@ -237,49 +218,15 @@ def test_loop(
f.write(";".join(output) + "\n")
def prepare_artifacts():
gen_path = os.path.join(os.getcwd(), "gen_shark_tank")
if not os.path.isdir(gen_path):
os.mkdir(gen_path)
for dirname in os.listdir(os.getcwd()):
for modelname in ["clip", "unet", "vae"]:
if modelname in dirname and "vmfb" not in dirname:
if not os.path.isdir(os.path.join(gen_path, dirname)):
shutil.move(os.path.join(os.getcwd(), dirname), gen_path)
print(f"Moved dir: {dirname} to {gen_path}.")
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--device", default="vulkan")
parser.add_argument(
"-b", "--beta", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument("-e", "--extra_args", type=str, default=None)
parser.add_argument(
"-u", "--upload", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument(
"-g", "--gen", action=argparse.BooleanOptionalAction, default=False
)
if __name__ == "__main__":
args = parser.parse_args()
print(args)
extra_args = []
if args.extra_args:
for arg in args.extra_args.split(","):
extra_args.append(arg)
test_loop(
args.device,
args.beta,
extra_args,
args.upload,
args.exit_on_fail,
args.gen,
)
if args.gen:
prepare_artifacts()
test_loop(args.device, args.beta, [])

View File

@@ -27,7 +27,7 @@ include(FetchContent)
FetchContent_Declare(
iree
GIT_REPOSITORY https://github.com/nod-ai/srt.git
GIT_REPOSITORY https://github.com/nod-ai/shark-runtime.git
GIT_TAG shark
GIT_SUBMODULES_RECURSE OFF
GIT_SHALLOW OFF

View File

@@ -40,7 +40,7 @@ cmake --build build/
*Prepare the model*
```bash
wget https://storage.googleapis.com/shark_tank/latest/resnet50_tf/resnet50_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux resnet50_tf.mlir -o resnet50_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --iree-llvmcpu-embedded-linker-path=`python3 -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])'`/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --mlir-pass-pipeline-crash-reproducer=ist/core-reproducer.mlir --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 resnet50_tf.mlir -o resnet50_tf.vmfb
```
*Prepare the input*
@@ -65,18 +65,18 @@ A tool for benchmarking other models is built and can be invoked with a command
see `./build/vulkan_gui/iree-vulkan-gui --help` for an explanation on the function input. For example, stable diffusion unet can be tested with the following commands:
```bash
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/stable_diff_tf.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux stable_diff_tf.mlir -o stable_diff_tf.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 stable_diff_tf.mlir -o stable_diff_tf.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=2x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32
```
VAE and Autoencoder are also available
```bash
# VAE
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/vae_tf/vae.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux vae.mlir -o vae.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 vae.mlir -o vae.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x4x64x64xf32
# CLIP Autoencoder
wget https://storage.googleapis.com/shark_tank/quinn/stable_diff_tf/clip_tf/clip_autoencoder.mlir
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux clip_autoencoder.mlir -o clip_autoencoder.vmfb
iree-compile --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host -iree-vulkan-target-triple=rdna2-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 clip_autoencoder.mlir -o clip_autoencoder.vmfb
./build/vulkan_gui/iree-vulkan-gui --module-file=stable_diff_tf.vmfb --function_input=1x77xi32 --function_input=1x77xi32
```

View File

@@ -55,7 +55,7 @@ The command line for compilation will start something like this, where the `-` n
The `-o output_filename.vmfb` flag can be used to specify the location to save the compiled vmfb. Note that a dump of the
dispatches that can be compiled + run in isolation can be generated by adding `--iree-hal-dump-executable-benchmarks-to=/some/directory`. Say, if they are in the `benchmarks` directory, the following compile/run commands would work for Vulkan on RDNA3.
```
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna3-unknown-linux --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.mlir -o benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb
iree-benchmark-module --module=benchmarks/module_forward_dispatch_${NUM}_vulkan_spirv_fb.vmfb --function=forward --device=vulkan
```
@@ -63,8 +63,8 @@ Where `${NUM}` is the dispatch number that you want to benchmark/profile in isol
### Enabling Tracy for Vulkan profiling
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime (SRT) builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SRT/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
To begin profiling with Tracy, a build of IREE runtime with tracing enabled is needed. SHARK-Runtime builds an
instrumented version alongside the normal version nightly (.whls typically found [here](https://github.com/nod-ai/SHARK-Runtime/releases)), however this is only available for Linux. For Windows, tracing can be enabled by enabling a CMake flag.
```
$env:IREE_ENABLE_RUNTIME_TRACING="ON"
```

192
inference/CMakeLists.txt Normal file
View File

@@ -0,0 +1,192 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required(VERSION 3.17)
project(sharkbackend LANGUAGES C CXX)
#
# Options
#
option(TRITON_ENABLE_GPU "Enable GPU support in backend" ON)
option(TRITON_ENABLE_STATS "Include statistics collections in backend" ON)
set(TRITON_COMMON_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/common repo")
set(TRITON_CORE_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/core repo")
set(TRITON_BACKEND_REPO_TAG "main" CACHE STRING "Tag for triton-inference-server/backend repo")
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()
#
# Dependencies
#
# FetchContent requires us to include the transitive closure of all
# repos that we depend on so that we can override the tags.
#
include(FetchContent)
FetchContent_Declare(
repo-common
GIT_REPOSITORY https://github.com/triton-inference-server/common.git
GIT_TAG ${TRITON_COMMON_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-core
GIT_REPOSITORY https://github.com/triton-inference-server/core.git
GIT_TAG ${TRITON_CORE_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_Declare(
repo-backend
GIT_REPOSITORY https://github.com/triton-inference-server/backend.git
GIT_TAG ${TRITON_BACKEND_REPO_TAG}
GIT_SHALLOW ON
)
FetchContent_MakeAvailable(repo-common repo-core repo-backend)
#
# The backend must be built into a shared library. Use an ldscript to
# hide all symbols except for the TRITONBACKEND API.
#
configure_file(src/libtriton_dshark.ldscript libtriton_dshark.ldscript COPYONLY)
add_library(
triton-dshark-backend SHARED
src/dshark.cc
#src/dshark_driver_module.c
)
add_library(
SharkBackend::triton-dshark-backend ALIAS triton-dshark-backend
)
target_include_directories(
triton-dshark-backend
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_BINARY_DIR}/lib/cmake/mlir")
add_subdirectory(thirdparty/shark-runtime EXCLUDE_FROM_ALL)
target_link_libraries(triton-dshark-backend PRIVATE iree_base_base
iree_hal_hal
iree_hal_cuda_cuda
iree_hal_cuda_registration_registration
iree_hal_vmvx_registration_registration
iree_hal_dylib_registration_registration
iree_modules_hal_hal
iree_vm_vm
iree_vm_bytecode_module
iree_hal_local_loaders_system_library_loader
iree_hal_local_loaders_vmvx_module_loader
)
target_compile_features(triton-dshark-backend PRIVATE cxx_std_11)
target_link_libraries(
triton-dshark-backend
PRIVATE
triton-core-serverapi # from repo-core
triton-core-backendapi # from repo-core
triton-core-serverstub # from repo-core
triton-backend-utils # from repo-backend
)
if(WIN32)
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
)
else()
set_target_properties(
triton-dshark-backend PROPERTIES
POSITION_INDEPENDENT_CODE ON
OUTPUT_NAME triton_dshark
LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_dshark.ldscript
LINK_FLAGS "-Wl,--version-script libtriton_dshark.ldscript"
)
endif()
#
# Install
#
include(GNUInstallDirs)
set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/SharkBackend)
install(
TARGETS
triton-dshark-backend
EXPORT
triton-dshark-backend-targets
LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/dshark
)
install(
EXPORT
triton-dshark-backend-targets
FILE
SharkBackendTargets.cmake
NAMESPACE
SharkBackend::
DESTINATION
${INSTALL_CONFIGDIR}
)
include(CMakePackageConfigHelpers)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/SharkBackendConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
)
install(
FILES
${CMAKE_CURRENT_BINARY_DIR}/SharkBackendConfig.cmake
DESTINATION ${INSTALL_CONFIGDIR}
)
#
# Export from build tree
#
export(
EXPORT triton-dshark-backend-targets
FILE ${CMAKE_CURRENT_BINARY_DIR}/SharkBackendTargets.cmake
NAMESPACE SharkBackend::
)
export(PACKAGE SharkBackend)

100
inference/README.md Normal file
View File

@@ -0,0 +1,100 @@
# SHARK Triton Backend
The triton backend for shark.
# Build
Install SHARK
```
git clone https://github.com/nod-ai/SHARK.git
# skip above step if dshark is already installed
cd SHARK/inference
```
install dependancies
```
apt-get install patchelf rapidjson-dev python3-dev
git submodule update --init
```
update the submodules of iree
```
cd thirdparty/shark-runtime
git submodule update --init
```
Next, make the backend and install it
```
cd ../..
mkdir build && cd build
cmake -DTRITON_ENABLE_GPU=ON \
-DIREE_HAL_DRIVER_CUDA=ON \
-DIREE_TARGET_BACKEND_CUDA=ON \
-DMLIR_ENABLE_CUDA_RUNNER=ON \
-DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install \
-DTRITON_BACKEND_REPO_TAG=r22.02 \
-DTRITON_CORE_REPO_TAG=r22.02 \
-DTRITON_COMMON_REPO_TAG=r22.02 ..
make install
```
# Incorporating into Triton
There are much more in depth explenations for the following steps in triton's documentation:
https://github.com/triton-inference-server/server/blob/main/docs/compose.md#triton-with-unsupported-and-custom-backends
There should be a file at /build/install/backends/dshark/libtriton_dshark.so. You will need to copy it into your triton server image.
More documentation is in the link above, but to create the docker image, you need to run the compose.py command in the triton-backend server repo
To first build your image, clone the tritonserver repo.
```
git clone https://github.com/triton-inference-server/server.git
```
then run `compose.py` to build a docker compose file
```
cd server
python3 compose.py --repoagent checksum --dry-run
```
Because dshark is a third party backend, you will need to manually modify the `Dockerfile.compose` to include the dshark backend. To do this, in the Dockerfile.compose file produced, copy this line.
the dshark backend will be located in the build folder from earlier under `/build/install/backends`
```
COPY /path/to/build/install/backends/dshark /opt/tritonserver/backends/dshark
```
Next run
```
docker build -t tritonserver_custom -f Dockerfile.compose .
docker run -it --gpus=1 --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
where `path/to/model_repos` is where you are storing the models you want to run
if your not using gpus, omit `--gpus=1`
```
docker run -it --net=host -v/path/to/model_repos:/models tritonserver_custom:latest tritonserver --model-repository=/models
```
# Setting up a model
to include a model in your backend, add a directory with your model name to your model repository directory. examples of models can be seen here: https://github.com/triton-inference-server/backend/tree/main/examples/model_repos/minimal_models
make sure to adjust the input correctly in the config.pbtxt file, and save a vmfb file under 1/model.vmfb
# CUDA
if you're having issues with cuda, make sure your correct drivers are installed, and that `nvidia-smi` works, and also make sure that the nvcc compiler is on the path.

View File

@@ -0,0 +1,39 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
include(CMakeFindDependencyMacro)
get_filename_component(
SHARKBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH
)
list(APPEND CMAKE_MODULE_PATH ${SHARKBACKEND_CMAKE_DIR})
if(NOT TARGET SharkBackend::triton-dshark-backend)
include("${SHARKBACKEND_CMAKE_DIR}/SharkBackendTargets.cmake")
endif()
set(SHARKBACKEND_LIBRARIES SharkBackend::triton-dshark-backend)

1409
inference/src/dshark.cc Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,30 @@
# Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
{
global:
TRITONBACKEND_*;
local: *;
};

View File

@@ -6,15 +6,15 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Temporary workaround for transformers/__init__.py.
path_to_transformers_hook = Path(
# Temorary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
if path_to_transformers_hook.is_file():
if path_to_tranformers_hook.is_file():
pass
else:
with open(path_to_transformers_hook, "w") as f:
with open(path_to_tranformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")

View File

@@ -5,7 +5,7 @@ requires = [
"packaging",
"numpy>=1.22.4",
"torch-mlir>=20230620.875",
"torch-mlir>=20221021.633",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
]

View File

@@ -8,8 +8,19 @@ torchvision
tqdm
#iree-compiler | iree-runtime should already be installed
#these dont work ok osx
#iree-tools-tflite
#iree-tools-xla
#iree-tools-tf
# TensorFlow and JAX.
gin-config
tensorflow-macos
tensorflow-metal
#tf-models-nightly
#tensorflow-text-nightly
transformers
tensorflow-probability
#jax[cpu]
# tflitehub dependencies.

View File

@@ -3,19 +3,29 @@
numpy>1.22.4
pytorch-triton
torchvision
torchvision==0.16.0.dev20230322
tabulate
tqdm
#iree-compiler | iree-runtime should already be installed
iree-tools-tflite
iree-tools-xla
iree-tools-tf
# Modelling and JAX.
# TensorFlow and JAX.
gin-config
tensorflow>2.11
keras
#tf-models-nightly
#tensorflow-text-nightly
transformers
diffusers
#tensorflow-probability
#jax[cpu]
# tflitehub dependencies.
Pillow
# Testing and support.

View File

@@ -1,6 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre
setuptools
wheel
@@ -18,18 +15,16 @@ Pillow
parameterized
# Add transformers, diffusers and scipy since it most commonly used
tokenizers==0.13.3
transformers
diffusers
#accelerate is now required for diffusers import from ckpt.
accelerate
scipy
ftfy
gradio==3.44.3
gradio
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64
safetensors==0.3.1
safetensors
opencv-python
scikit-image
pytorch_lightning # for runwayml models
@@ -40,11 +35,10 @@ py-cpuinfo
tiktoken # for codegen
joblib # for langchain
timm # for MiniGPT4
langchain
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
brevitas @ git+https://github.com/Xilinx/brevitas.git@dev

View File

@@ -90,8 +90,8 @@ python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
pip install --upgrade -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html
Write-Host "Build and installation completed successfully"
Write-Host "Source your venv with ./shark.venv/Scripts/activate"

View File

@@ -86,7 +86,6 @@ $PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
if [ "$torch_mlir_bin" = true ]; then
if [[ $(uname -s) = 'Darwin' ]]; then
echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
$PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
$PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
$PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
@@ -104,7 +103,7 @@ else
fi
if [[ -z "${USE_IREE}" ]]; then
rm .use-iree
RUNTIME="https://nod-ai.github.io/SRT/pip-release-links.html"
RUNTIME="https://nod-ai.github.io/SHARK-Runtime/pip-release-links.html"
else
touch ./.use-iree
RUNTIME="https://openxla.github.io/iree/pip-release-links.html"
@@ -129,21 +128,16 @@ if [[ ! -z "${IMPORTER}" ]]; then
fi
fi
if [[ $(uname -s) = 'Darwin' ]]; then
PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
else
PYTORCH_URL=https://download.pytorch.org/whl/nightly/cpu/
fi
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f https://download.pytorch.org/whl/nightly/torch/
$PYTHON -m pip install --no-warn-conflicts -e . -f https://llvm.github.io/torch-mlir/package-index/ -f ${RUNTIME} -f ${PYTORCH_URL}
if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
T_VER=$($PYTHON -m pip show torch | grep Version)
T_VER_MIN=${T_VER:14:12}
TORCH_VERSION=${T_VER:9:17}
TV_VER=$($PYTHON -m pip show torchvision | grep Version)
TV_VER_MAJ=${TV_VER:9:6}
$PYTHON -m pip uninstall -y torchvision
$PYTHON -m pip install torchvision==${TV_VER_MAJ}${T_VER_MIN} --no-deps -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
TV_VERSION=${TV_VER:9:18}
$PYTHON -m pip uninstall -y torch torchvision
$PYTHON -m pip install -U --pre --no-warn-conflicts triton
$PYTHON -m pip install --no-deps https://download.pytorch.org/whl/nightly/cu118/torch-${TORCH_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl https://download.pytorch.org/whl/nightly/cu118/torchvision-${TV_VERSION}%2Bcu118-cp311-cp311-linux_x86_64.whl
if [ $? -eq 0 ];then
echo "Successfully Installed torch + cu118."
else
@@ -151,8 +145,14 @@ if [[ $(uname -s) = 'Linux' && ! -z "${IMPORTER}" ]]; then
fi
fi
if [[ -z "${NO_BREVITAS}" ]]; then
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev
if [[ ! -z "${ONNX}" ]]; then
echo "${Yellow}Installing ONNX and onnxruntime for benchmarks..."
$PYTHON -m pip install onnx onnxruntime psutil
if [ $? -eq 0 ];then
echo "Successfully installed ONNX and ONNX runtime."
else
echo "Could not install ONNX." >&2
fi
fi
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then

View File

@@ -43,7 +43,9 @@ if __name__ == "__main__":
minilm_mlir, func_name = mlir_importer.import_mlir(
is_dynamic=False, tracing_required=True
)
shark_module = SharkInference(minilm_mlir)
shark_module = SharkInference(
minilm_mlir, func_name, mlir_dialect="linalg"
)
shark_module.compile()
token_logits = torch.tensor(shark_module.forward(inputs))
mask_id = torch.where(

View File

@@ -0,0 +1,325 @@
import torch
from torch.nn.utils import stateless
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_trainer import SharkTrainer
import argparse
import sys
import numpy as np
import torch.nn as nn
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
import torch_mlir
import extend_distributed as ext_dist
### define dlrm in PyTorch ###
class DLRM_Net(nn.Module):
def create_mlp(self, ln, sigmoid_layer):
# build MLP layer by layer
layers = nn.ModuleList()
for i in range(0, ln.size - 1):
n = ln[i]
m = ln[i + 1]
# construct fully connected operator
LL = nn.Linear(int(n), int(m), bias=True)
# initialize the weights
# with torch.no_grad():
# custom Xavier input, output or two-sided fill
mean = 0.0 # std_dev = np.sqrt(variance)
std_dev = np.sqrt(2 / (m + n)) # np.sqrt(1 / m) # np.sqrt(1 / n)
W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
std_dev = np.sqrt(1 / m) # np.sqrt(2 / (m + 1))
bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
LL.weight.data = torch.tensor(W, requires_grad=True)
LL.bias.data = torch.tensor(bt, requires_grad=True)
# approach 2
# LL.weight.data.copy_(torch.tensor(W))
# LL.bias.data.copy_(torch.tensor(bt))
# approach 3
# LL.weight = Parameter(torch.tensor(W),requires_grad=True)
# LL.bias = Parameter(torch.tensor(bt),requires_grad=True)
layers.append(LL)
# construct sigmoid or relu operator
if i == sigmoid_layer:
layers.append(nn.Sigmoid())
else:
layers.append(nn.ReLU())
# approach 1: use ModuleList
# return layers
# approach 2: use Sequential container to wrap all layers
return torch.nn.Sequential(*layers)
def create_emb(self, m, ln, weighted_pooling=None):
emb_l = nn.ModuleList()
v_W_l = []
for i in range(0, ln.size):
n = ln[i]
# construct embedding operator
EE = nn.EmbeddingBag(n, m, mode="sum")
# initialize embeddings
# nn.init.uniform_(EE.weight, a=-np.sqrt(1 / n), b=np.sqrt(1 / n))
W = np.random.uniform(
low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m)
).astype(np.float32)
# approach 1
print(W)
EE.weight.data = torch.tensor(W, requires_grad=True)
# approach 2
# EE.weight.data.copy_(torch.tensor(W))
# approach 3
# EE.weight = Parameter(torch.tensor(W),requires_grad=True)
if weighted_pooling is None:
v_W_l.append(None)
else:
v_W_l.append(torch.ones(n, dtype=torch.float32))
emb_l.append(EE)
return emb_l, v_W_l
def __init__(
self,
m_spa=None,
ln_emb=None,
ln_bot=None,
ln_top=None,
arch_interaction_op=None,
arch_interaction_itself=False,
sigmoid_bot=-1,
sigmoid_top=-1,
weighted_pooling=None,
):
super(DLRM_Net, self).__init__()
if (
(m_spa is not None)
and (ln_emb is not None)
and (ln_bot is not None)
and (ln_top is not None)
and (arch_interaction_op is not None)
):
# save arguments
self.output_d = 0
self.arch_interaction_op = arch_interaction_op
self.arch_interaction_itself = arch_interaction_itself
if weighted_pooling is not None and weighted_pooling != "fixed":
self.weighted_pooling = "learned"
else:
self.weighted_pooling = weighted_pooling
# create operators
self.emb_l, w_list = self.create_emb(
m_spa, ln_emb, weighted_pooling
)
if self.weighted_pooling == "learned":
self.v_W_l = nn.ParameterList()
for w in w_list:
self.v_W_l.append(nn.Parameter(w))
else:
self.v_W_l = w_list
self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
self.top_l = self.create_mlp(ln_top, sigmoid_top)
def apply_mlp(self, x, layers):
return layers(x)
def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
# WARNING: notice that we are processing the batch at once. We implicitly
# assume that the data is laid out such that:
# 1. each embedding is indexed with a group of sparse indices,
# corresponding to a single lookup
# 2. for each embedding the lookups are further organized into a batch
# 3. for a list of embedding tables there is a list of batched lookups
# TORCH-MLIR
# We are passing all the embeddings as arguments for easy parsing.
ly = []
for k, sparse_index_group_batch in enumerate(lS_i):
sparse_offset_group_batch = lS_o[k]
# embedding lookup
# We are using EmbeddingBag, which implicitly uses sum operator.
# The embeddings are represented as tall matrices, with sum
# happening vertically across 0 axis, resulting in a row vector
# E = emb_l[k]
if v_W_l[k] is not None:
per_sample_weights = v_W_l[k].gather(
0, sparse_index_group_batch
)
else:
per_sample_weights = None
E = emb_l[k]
V = E(
sparse_index_group_batch,
sparse_offset_group_batch,
per_sample_weights=per_sample_weights,
)
ly.append(V)
return ly
def interact_features(self, x, ly):
if self.arch_interaction_op == "dot":
# concatenate dense and sparse features
(batch_size, d) = x.shape
T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
# perform a dot product
Z = torch.bmm(T, torch.transpose(T, 1, 2))
# append dense feature with the interactions (into a row vector)
# approach 1: all
# Zflat = Z.view((batch_size, -1))
# approach 2: unique
_, ni, nj = Z.shape
# approach 1: tril_indices
# offset = 0 if self.arch_interaction_itself else -1
# li, lj = torch.tril_indices(ni, nj, offset=offset)
# approach 2: custom
offset = 1 if self.arch_interaction_itself else 0
li = torch.tensor(
[i for i in range(ni) for j in range(i + offset)]
)
lj = torch.tensor(
[j for i in range(nj) for j in range(i + offset)]
)
Zflat = Z[:, li, lj]
# concatenate dense features and interactions
R = torch.cat([x] + [Zflat], dim=1)
elif self.arch_interaction_op == "cat":
# concatenation features (into a row vector)
R = torch.cat([x] + ly, dim=1)
else:
sys.exit(
"ERROR: --arch-interaction-op="
+ self.arch_interaction_op
+ " is not supported"
)
return R
def forward(self, dense_x, lS_o, *lS_i):
return self.sequential_forward(dense_x, lS_o, lS_i)
def sequential_forward(self, dense_x, lS_o, lS_i):
# process dense features (using bottom mlp), resulting in a row vector
x = self.apply_mlp(dense_x, self.bot_l)
# debug prints
# print("intermediate")
# print(x.detach().cpu().numpy())
# process sparse features(using embeddings), resulting in a list of row vectors
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
# for y in ly:
# print(y.detach().cpu().numpy())
# interact features (dense and sparse)
z = self.interact_features(x, ly)
# print(z.detach().cpu().numpy())
# obtain probability of a click (using top mlp)
p = self.apply_mlp(z, self.top_l)
# # clamp output if needed
# if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
# z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
# else:
# z = p
return p
def dash_separated_ints(value):
vals = value.split("-")
for val in vals:
try:
int(val)
except ValueError:
raise argparse.ArgumentTypeError(
"%s is not a valid dash separated list of ints" % value
)
return value
# model related parameters
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
parser.add_argument(
"--arch-embedding-size", type=dash_separated_ints, default="4-3-2"
)
# j will be replaced with the table number
parser.add_argument(
"--arch-mlp-bot", type=dash_separated_ints, default="4-3-2"
)
parser.add_argument(
"--arch-mlp-top", type=dash_separated_ints, default="8-2-1"
)
parser.add_argument(
"--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot"
)
parser.add_argument(
"--arch-interaction-itself", action="store_true", default=False
)
parser.add_argument("--weighted-pooling", type=str, default=None)
args = parser.parse_args()
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
ln_top = np.fromstring(args.arch_mlp_top, dtype=int, sep="-")
m_den = ln_bot[0]
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
m_spa = args.arch_sparse_feature_size
ln_emb = np.asarray(ln_emb)
num_fea = ln_emb.size + 1 # num sparse + num dense features
# Initialize the model.
dlrm_model = DLRM_Net(
m_spa=m_spa,
ln_emb=ln_emb,
ln_bot=ln_bot,
ln_top=ln_top,
arch_interaction_op=args.arch_interaction_op,
)
def get_sorted_params(named_params):
return [i[1] for i in sorted(named_params.items())]
dense_inp = torch.tensor([[0.6965, 0.2861, 0.2269, 0.5513]])
vs0 = torch.tensor([[0], [0], [0]], dtype=torch.int64)
vsi = torch.tensor([1, 2, 3]), torch.tensor([1]), torch.tensor([1])
input_dlrm = (dense_inp, vs0, *vsi)
mlir_importer = SharkImporter(
dlrm_model,
input_dlrm,
frontend="torch",
)
dlrm_mlir = torch_mlir.compile(dlrm_model, input_dlrm, torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=True)
print(dlrm_mlir)
def forward(params, buffers, args):
params_and_buffers = {**params, **buffers}
stateless.functional_call(
dlrm_model, params_and_buffers, args, {}
).sum().backward()
optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
# optim.load_state_dict(optim_state)
optim.step()
return params, buffers
shark_module = SharkTrainer(dlrm_model, input_dlrm)
print("________________________________________________________________________")
shark_module.compile(forward)
print("________________________________________________________________________")
shark_module.train(num_iters=2)
print("training done")

View File

@@ -13,7 +13,7 @@
# limitations under the License.
## Common utilities to be shared by iree utilities.
import functools
import os
import sys
import subprocess
@@ -52,8 +52,6 @@ def iree_device_map(device):
)
if len(uri_parts) == 1:
return iree_driver
elif "rocm" in uri_parts:
return "rocm"
else:
return f"{iree_driver}://{uri_parts[1]}"
@@ -65,6 +63,7 @@ def get_supported_device_list():
_IREE_DEVICE_MAP = {
"cpu": "local-task",
"cpu-task": "local-task",
"AMD-AIE": "local-task",
"cpu-sync": "local-sync",
"cuda": "cuda",
"vulkan": "vulkan",
@@ -83,6 +82,7 @@ def iree_target_map(device):
_IREE_TARGET_MAP = {
"cpu": "llvm-cpu",
"cpu-task": "llvm-cpu",
"AMD-AIE": "llvm-cpu",
"cpu-sync": "llvm-cpu",
"cuda": "cuda",
"vulkan": "vulkan",
@@ -93,7 +93,6 @@ _IREE_TARGET_MAP = {
# Finds whether the required drivers are installed for the given device.
@functools.cache
def check_device_drivers(device):
"""Checks necessary drivers present for gpu and vulkan devices"""
if "://" in device:
@@ -121,10 +120,7 @@ def check_device_drivers(device):
return False
elif device == "rocm":
try:
if sys.platform == "win32":
subprocess.check_output("hipinfo")
else:
subprocess.check_output("rocminfo")
subprocess.check_output("rocminfo")
except Exception:
return True

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
from shark.iree_utils._common import run_cmd, iree_device_map
from shark.iree_utils.cpu_utils import get_cpu_count
import numpy as np
@@ -61,12 +62,16 @@ def build_benchmark_args(
and whether it is training or not.
Outputs: string that execute benchmark-module on target model.
"""
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
path = benchmark_module.__path__[0]
if platform.system() == "Windows":
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module.exe"
)
time_extractor = None
else:
benchmarker_path = os.path.join(path, "iree-benchmark-module")
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module"
)
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
@@ -101,13 +106,15 @@ def build_benchmark_args_non_tensor_input(
and whether it is training or not.
Outputs: string that execute benchmark-module on target model.
"""
path = os.path.join(os.environ["VIRTUAL_ENV"], "bin")
path = benchmark_module.__path__[0]
if platform.system() == "Windows":
benchmarker_path = os.path.join(path, "iree-benchmark-module.exe")
time_extractor = None
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module.exe"
)
else:
benchmarker_path = os.path.join(path, "iree-benchmark-module")
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module"
)
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
if function_name:
@@ -132,7 +139,7 @@ def run_benchmark_module(benchmark_cl):
benchmark_path = benchmark_cl[0]
assert os.path.exists(
benchmark_path
), "Cannot find iree_benchmark_module, Please contact SHARK maintainer on discord."
), "Cannot find benchmark_module, Please contact SHARK maintainer on discord."
bench_stdout, bench_stderr = run_cmd(" ".join(benchmark_cl))
try:
regex_split = re.compile("(\d+[.]*\d*)( *)([a-zA-Z]+)")

View File

@@ -11,23 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import numpy as np
import os
import re
import tempfile
import time
from pathlib import Path
import iree.runtime as ireert
import iree.compiler as ireec
from shark.parser import shark_args
from .trace import DetailLogger
from ._common import iree_device_map, iree_target_map
from .cpu_utils import get_iree_cpu_rt_args
from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
@@ -46,7 +41,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "cpu":
from shark.iree_utils.cpu_utils import get_iree_cpu_args
data_tiling_flag = ["--iree-opt-data-tiling"]
data_tiling_flag = ["--iree-flow-enable-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
@@ -84,7 +79,7 @@ def get_iree_frontend_args(frontend):
elif frontend in ["tensorflow", "tf", "mhlo", "stablehlo"]:
return [
"--iree-llvmcpu-target-cpu-features=host",
"--iree-input-demote-i64-to-i32",
"--iree-flow-demote-i64-to-i32",
]
else:
# Frontend not found.
@@ -92,27 +87,13 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
def get_iree_common_args(debug=False):
common_args = [
"--iree-stream-resource-max-allocation-size=4294967295",
def get_iree_common_args():
return [
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
]
if debug == True:
common_args.extend(
[
"--iree-opt-strip-assertions=false",
"--verify=true",
]
)
else:
common_args.extend(
[
"--iree-opt-strip-assertions=true",
"--verify=false",
]
)
return common_args
# Args that are suitable only for certain models or groups of models.
@@ -291,17 +272,14 @@ def compile_module_to_flatbuffer(
model_config_path,
extra_args,
model_name="None",
debug=False,
compile_str=False,
):
# Setup Compile arguments wrt to frontends.
input_type = "auto"
input_type = ""
args = get_iree_frontend_args(frontend)
args += get_iree_device_args(device, extra_args)
args += get_iree_common_args(debug=debug)
args += get_iree_common_args()
args += get_model_specific_args()
args += extra_args
args += shark_args.additional_compile_args
if frontend in ["tensorflow", "tf"]:
input_type = "auto"
@@ -312,7 +290,10 @@ def compile_module_to_flatbuffer(
elif frontend in ["tm_tensor"]:
input_type = ireec.InputType.TM_TENSOR
if compile_str:
# TODO: make it simpler.
# Compile according to the input type, else just try compiling.
if input_type != "":
# Currently for MHLO/TOSA.
flatbuffer_blob = ireec.compile_str(
module,
target_backends=[iree_target_map(device)],
@@ -320,10 +301,9 @@ def compile_module_to_flatbuffer(
input_type=input_type,
)
else:
assert os.path.isfile(module)
flatbuffer_blob = ireec.compile_file(
# Currently for Torch.
flatbuffer_blob = ireec.compile_str(
module,
input_type=input_type,
target_backends=[iree_target_map(device)],
extra_args=args,
)
@@ -337,6 +317,7 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
@@ -356,70 +337,58 @@ def get_iree_module(flatbuffer_blob, device, device_idx=None):
def load_vmfb_using_mmap(
flatbuffer_blob_or_path, device: str, device_idx: int = None
):
print(f"Loading module {flatbuffer_blob_or_path}...")
if "rocm" in device:
device = "rocm"
with DetailLogger(timeout=2.5) as dl:
# First get configs.
if device_idx is not None:
dl.log(f"Mapping device id: {device_idx}")
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
dl.log(f"ireert.get_driver()")
instance = ireert.VmInstance()
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device_by_uri(
device,
allocators=[],
)
# First get configs.
if device_idx is not None:
device = iree_device_map(device)
print("registering device id: ", device_idx)
haldriver = ireert.get_driver(device)
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)
dl.log("get_iree_runtime_config")
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(
config.vm_instance, flatbuffer_blob_or_path
)
dl.log(f"mmap {flatbuffer_blob_or_path}")
ctx = ireert.SystemContext(config=config)
dl.log(f"ireert.SystemContext created")
if "vulkan" in device:
# Vulkan pipeline creation consumes significant amount of time.
print(
"\tCompiling Vulkan shaders. This may take a few minutes."
)
ctx.add_vm_module(mmaped_vmfb)
dl.log(f"module initialized")
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
dl.log(f"mmap temp {vmfb_file_path}")
return mmaped_vmfb, config, temp_file_to_unlink
haldevice = haldriver.create_device(
haldriver.query_available_devices()[device_idx]["device_id"],
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
# (This would arise if we're invoking `load_module` from a SharkInference obj)
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (
isinstance(flatbuffer_blob_or_path, str)
and ".vmfb" in flatbuffer_blob_or_path
):
vmfb_file_path = flatbuffer_blob_or_path
mmaped_vmfb = ireert.VmModule.mmap(instance, flatbuffer_blob_or_path)
ctx = ireert.SystemContext(config=config)
ctx.add_vm_module(mmaped_vmfb)
mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name)
else:
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(flatbuffer_blob_or_path)
tf.flush()
vmfb_file_path = tf.name
temp_file_to_unlink = vmfb_file_path
mmaped_vmfb = ireert.VmModule.mmap(instance, vmfb_file_path)
return mmaped_vmfb, config, temp_file_to_unlink
def get_iree_compiled_module(
@@ -430,18 +399,10 @@ def get_iree_compiled_module(
extra_args: list = [],
device_idx: int = None,
mmap: bool = False,
debug: bool = False,
compile_str: bool = False,
):
"""Given a module returns the compiled .vmfb and configs"""
flatbuffer_blob = compile_module_to_flatbuffer(
module,
device,
frontend,
model_config_path,
extra_args,
debug,
compile_str,
module, device, frontend, model_config_path, extra_args
)
temp_file_to_unlink = None
# TODO: Currently mmap=True control flow path has been switched off for mmap.
@@ -449,6 +410,7 @@ def get_iree_compiled_module(
# we're setting delete=False when creating NamedTemporaryFile. That's why
# I'm getting hold of the name of the temporary file in `temp_file_to_unlink`.
if mmap:
print(f"Will load the compiled module as a mmapped temporary file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_blob, device, device_idx
)
@@ -472,6 +434,7 @@ def load_flatbuffer(
):
temp_file_to_unlink = None
if mmap:
print(f"Loading flatbuffer at {flatbuffer_path} as a mmapped file")
vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap(
flatbuffer_path, device, device_idx
)
@@ -497,18 +460,10 @@ def export_iree_module_to_vmfb(
model_config_path: str = None,
module_name: str = None,
extra_args: list = [],
debug: bool = False,
compile_str: bool = False,
):
# Compiles the module given specs and saves it as .vmfb file.
flatbuffer_blob = compile_module_to_flatbuffer(
module,
device,
mlir_dialect,
model_config_path,
extra_args,
debug,
compile_str,
module, device, mlir_dialect, model_config_path, extra_args
)
if module_name is None:
device_name = (
@@ -516,9 +471,9 @@ def export_iree_module_to_vmfb(
)
module_name = f"{mlir_dialect}_{device_name}"
filename = os.path.join(directory, module_name + ".vmfb")
print(f"Saved vmfb in {filename}.")
with open(filename, "wb") as f:
f.write(flatbuffer_blob)
print(f"Saved vmfb in {filename}.")
return filename
@@ -543,56 +498,37 @@ def get_results(
config,
frontend="torch",
send_to_host=True,
debug_timeout: float = 5.0,
):
"""Runs a .vmfb file given inputs and config and returns output."""
with DetailLogger(debug_timeout) as dl:
device_inputs = []
for input_array in input:
dl.log(f"Load to device: {input_array.shape}")
device_inputs.append(
ireert.asdevicearray(config.device, input_array)
)
dl.log(f"Invoke function: {function_name}")
result = compiled_vm[function_name](*device_inputs)
dl.log(f"Invoke complete")
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
for val in result:
dl.log(f"Result to host: {val.shape}")
result_tensors.append(np.asarray(val, val.dtype))
else:
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
device_inputs = [ireert.asdevicearray(config.device, a) for a in input]
result = compiled_vm[function_name](*device_inputs)
result_tensors = []
if isinstance(result, tuple):
if send_to_host:
for val in result:
result_tensors.append(np.asarray(val, val.dtype))
else:
if send_to_host and result is not None:
dl.log("Result to host")
return result.to_host()
return result
dl.log("Execution complete")
for val in result:
result_tensors.append(val)
return result_tensors
elif isinstance(result, dict):
data = list(result.items())
if send_to_host:
res = np.array(data, dtype=object)
return np.copy(res)
return data
else:
if send_to_host and result is not None:
return result.to_host()
return result
@functools.cache
def get_iree_runtime_config(device):
device = iree_device_map(device)
haldriver = ireert.get_driver(device)
if device == "metal" and shark_args.device_allocator == "caching":
print(
"[WARNING] metal devices can not have a `caching` allocator."
"\nUsing default allocator `None`"
)
haldevice = haldriver.create_device_by_uri(
device,
# metal devices have a failure with caching allocators atm. blcking this util it gets fixed upstream.
allocators=shark_args.device_allocator if device != "metal" else None,
allocators=shark_args.device_allocator,
)
config = ireert.Config(device=haldevice)
return config

View File

@@ -14,7 +14,6 @@
# All the iree_cpu related functionalities go here.
import functools
import subprocess
import platform
from shark.parser import shark_args
@@ -31,7 +30,6 @@ def get_cpu_count():
# Get the default cpu args.
@functools.cache
def get_iree_cpu_args():
uname = platform.uname()
os_name, proc_name = uname.system, uname.machine
@@ -53,7 +51,6 @@ def get_iree_cpu_args():
# Get iree runtime flags for cpu
@functools.cache
def get_iree_cpu_rt_args():
default = get_cpu_count()
default = default if default <= 8 else default - 2

View File

@@ -14,15 +14,12 @@
# All the iree_gpu related functionalities go here.
import functools
import iree.runtime as ireert
import ctypes
import sys
from shark.parser import shark_args
# Get the default gpu args given the architecture.
@functools.cache
def get_iree_gpu_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
ireert.flags.parse_flags("--cuda_allow_inline_execution")
@@ -40,54 +37,23 @@ def get_iree_gpu_args():
# Get the default gpu args given the architecture.
@functools.cache
def get_iree_rocm_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
# get arch from hipinfo.
import os
# get arch from rocminfo.
import re
import subprocess
if sys.platform == "win32":
if "HIP_PATH" in os.environ:
rocm_path = os.environ["HIP_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
rocm_path = "C:\\AMD\\ROCM\\5.5"
else:
if "ROCM_PATH" in os.environ:
rocm_path = os.environ["ROCM_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
rocm_path = "/opt/rocm/"
try:
if sys.platform == "win32":
rocm_arch = re.search(
r"gfx\d{3,}",
subprocess.check_output("hipinfo", shell=True, text=True),
).group(0)
else:
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
except:
print(
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
)
rocm_arch = "gfx1100"
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
return [
f"--iree-rocm-target-chip={rocm_arch}",
"--iree-rocm-link-bc=true",
f"--iree-rocm-bc-dir={bc_path}",
"--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode",
]
@@ -99,7 +65,6 @@ CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
@functools.cache
def get_cuda_sm_cc():
libnames = ("libcuda.so", "libcuda.dylib", "nvcuda.dll")
for libname in libnames:

View File

@@ -14,15 +14,12 @@
# All the iree_vulkan related functionalities go here.
import functools
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
from sys import platform
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
@functools.cache
def get_metal_device_name(device_num=0):
iree_device_dump = run_cmd("iree-run-module --dump_devices")
iree_device_dump = iree_device_dump[0].split("\n\n")
@@ -89,10 +86,24 @@ def get_metal_triple_flag(device_name="", device_num=0, extra_args=[]):
def get_iree_metal_args(device_num=0, extra_args=[]):
# Add any metal spefic compilation flags here
# res_metal_flag = ["--iree-flow-demote-i64-to-i32"]
res_metal_flag = []
if len(extra_args) > 0:
res_metal_flag.extend(extra_args)
metal_triple_flag = None
for arg in extra_args:
if "-iree-metal-target-platform=" in arg:
print(f"Using target triple {arg} from command line args")
metal_triple_flag = arg
break
if metal_triple_flag is None:
metal_triple_flag = get_metal_triple_flag(extra_args=extra_args)
if metal_triple_flag is not None:
vulkan_target_env = get_vulkan_target_env_flag(
"-iree-vulkan-target-triple=m1-moltenvk-macos"
)
res_metal_flag.append(vulkan_target_env)
return res_metal_flag

View File

@@ -1,76 +0,0 @@
# Copyright 2023 The Nod Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import os
import threading
import time
def _enable_detail_trace() -> bool:
return os.getenv("SHARK_DETAIL_TRACE", "0") == "1"
class DetailLogger:
"""Context manager which can accumulate detailed log messages.
Detailed log is only emitted if the operation takes a long time
or errors.
"""
def __init__(self, timeout: float):
self._timeout = timeout
self._messages: List[Tuple[float, str]] = []
self._start_time = time.time()
self._active = not _enable_detail_trace()
self._lock = threading.RLock()
self._cond = threading.Condition(self._lock)
self._thread = None
def __enter__(self):
self._thread = threading.Thread(target=self._run)
self._thread.start()
return self
def __exit__(self, type, value, traceback):
with self._lock:
self._active = False
self._cond.notify()
if traceback:
self.dump_on_error(f"exception")
def _run(self):
with self._lock:
timed_out = not self._cond.wait(self._timeout)
if timed_out:
self.dump_on_error(f"took longer than {self._timeout}s")
def log(self, msg):
with self._lock:
timestamp = time.time()
if self._active:
self._messages.append((timestamp, msg))
else:
print(f" +{(timestamp - self._start_time) * 1000}ms: {msg}")
def dump_on_error(self, summary: str):
with self._lock:
if self._active:
print(f"::: Detailed report ({summary}):")
for timestamp, msg in self._messages:
print(
f" +{(timestamp - self._start_time) * 1000}ms: {msg}"
)
self._active = False

View File

@@ -13,10 +13,8 @@
# limitations under the License.
from collections import OrderedDict
import functools
@functools.cache
def get_vulkan_target_env(vulkan_target_triple):
arch, product, os = vulkan_target_triple.split("=")[1].split("-")
triple = (arch, product, os)
@@ -54,11 +52,13 @@ def get_version(triple):
return "v1.3"
@functools.cache
def get_extensions(triple):
def make_ext_list(ext_list):
res = ", ".join(ext_list)
return f"[{res}]"
res = ""
for e in ext_list:
res += e + ", "
res = f"[{res[:-2]}]"
return res
arch, product, os = triple
if arch == "m1":
@@ -116,13 +116,12 @@ def get_extensions(triple):
]
if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
ext.append("VK_KHR_cooperative_matrix")
ext.append("VK_NV_cooperative_matrix")
if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
ext.append("VK_KHR_shader_integer_dot_product")
return make_ext_list(ext_list=ext)
@functools.cache
def get_vendor(triple):
arch, product, os = triple
if arch == "unknown":
@@ -147,7 +146,6 @@ def get_vendor(triple):
return "Unknown"
@functools.cache
def get_device_type(triple):
arch, product, _ = triple
if arch == "unknown":
@@ -168,7 +166,6 @@ def get_device_type(triple):
# get all the capabilities for the device
# TODO: make a dataclass for capabilites and init using vulkaninfo
@functools.cache
def get_vulkan_target_capabilities(triple):
def get_subgroup_val(l):
return int(sum([subgroup_feature[sgf] for sgf in l]))
@@ -244,7 +241,7 @@ def get_vulkan_target_capabilities(triple):
if arch == "rdna3":
# TODO: Get scope value
cap["coopmatCases"] = [
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>"
]
if product == "rx5700xt":
@@ -465,9 +462,9 @@ def get_vulkan_target_capabilities(triple):
cap["variablePointersStorageBuffer"] = True
cap["coopmatCases"] = [
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, accSat = false, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, accSat = false, scope = #vk.scope<Subgroup>",
"mSize = 8, nSize = 8, kSize = 32, aType = i8, bType = i8, cType = i32, resultType = i32, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, scope = #vk.scope<Subgroup>",
"mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f32, resultType = f32, scope = #vk.scope<Subgroup>",
]
elif arch == "adreno":
@@ -528,7 +525,7 @@ def get_vulkan_target_capabilities(triple):
cmc = ""
for case in v:
cmc += f"#vk.coop_matrix_props<{case}>, "
res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
res += f"cooperativeMatrixPropertiesNV = [{cmc[:-2]}], "
else:
res += f"{k} = {get_comma_sep_str(v)}, "
else:

View File

@@ -14,7 +14,6 @@
# All the iree_vulkan related functionalities go here.
import functools
from os import linesep
from shark.iree_utils._common import run_cmd
import iree.runtime as ireert
@@ -23,19 +22,10 @@ from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
from shark.parser import shark_args
@functools.cache
def get_all_vulkan_devices():
from iree.runtime import get_driver
driver = get_driver("vulkan")
device_list_src = driver.query_available_devices()
device_list_src.sort(key=lambda d: d["path"])
return [d["name"] for d in device_list_src]
@functools.cache
def get_vulkan_device_name(device_num=0):
vulkaninfo_list = get_all_vulkan_devices()
vulkaninfo_dump, _ = run_cmd("vulkaninfo")
vulkaninfo_dump = vulkaninfo_dump.split(linesep)
vulkaninfo_list = [s.strip() for s in vulkaninfo_dump if "deviceName" in s]
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
@@ -58,7 +48,6 @@ def get_os_name():
return "linux"
@functools.cache
def get_vulkan_target_triple(device_name):
"""This method provides a target triple str for specified vulkan device.
@@ -119,8 +108,6 @@ def get_vulkan_target_triple(device_name):
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif all(x in device_name for x in ("Radeon", "780M")):
triple = f"rdna3-780m-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
@@ -185,10 +172,11 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
return res_vulkan_flag
@functools.cache
def get_iree_vulkan_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={shark_args.vulkan_large_heap_block_size}",
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
f"--vulkan_vma_allocator={'true' if shark_args.vulkan_vma_allocator else 'false'}",
]
return vulkan_runtime_flags

View File

@@ -14,21 +14,8 @@
import argparse
import os
import shlex
import subprocess
class SplitStrToListAction(argparse.Action):
def __init__(self, option_strings, dest, *args, **kwargs):
super(SplitStrToListAction, self).__init__(
option_strings=option_strings, dest=dest, *args, **kwargs
)
def __call__(self, parser, namespace, values, option_string=None):
del parser, option_string
setattr(namespace, self.dest, shlex.split(values[0]))
parser = argparse.ArgumentParser(description="SHARK runner.")
parser.add_argument(
@@ -37,13 +24,6 @@ parser.add_argument(
default="cpu",
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
)
parser.add_argument(
"--additional_compile_args",
default=list(),
nargs=1,
action=SplitStrToListAction,
help="Additional arguments to pass to the compiler. These are appended as the last arguments.",
)
parser.add_argument(
"--enable_tf32",
type=bool,
@@ -134,7 +114,7 @@ parser.add_argument(
"--device_allocator",
type=str,
nargs="*",
default=["caching"],
default=[],
help="Specifies one or more HAL device allocator specs "
"to augment the base device allocator",
choices=["debug", "caching"],
@@ -153,6 +133,13 @@ parser.add_argument(
help="Profiles vulkan device and collects the .rdc info.",
)
parser.add_argument(
"--vulkan_large_heap_block_size",
default="2073741824",
help="Flag for setting VMA preferredLargeHeapBlockSize for "
"vulkan device, default is 4G.",
)
parser.add_argument(
"--vulkan_validation_layers",
default=False,
@@ -160,4 +147,11 @@ parser.add_argument(
help="Flag for disabling vulkan validation layers when benchmarking.",
)
parser.add_argument(
"--vulkan_vma_allocator",
default=True,
action=argparse.BooleanOptionalAction,
help="Flag for enabling / disabling Vulkan VMA Allocator.",
)
shark_args, unknown = parser.parse_known_args()

View File

@@ -13,11 +13,7 @@
# limitations under the License.
from shark.shark_runner import SharkRunner
from shark.iree_utils.compile_utils import (
export_iree_module_to_vmfb,
load_flatbuffer,
get_iree_runtime_config,
)
from shark.iree_utils.compile_utils import export_iree_module_to_vmfb
from shark.iree_utils.benchmark_utils import (
build_benchmark_args,
run_benchmark_module,
@@ -83,39 +79,22 @@ class SharkBenchmarkRunner(SharkRunner):
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
self.import_args = {}
self.temp_file_to_unlink = None
if not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
SharkRunner.__init__(
self,
mlir_module,
device,
self.mlir_dialect,
self.extra_args,
compile_vmfb=False,
compile_vmfb=True,
)
self.vmfb_file = export_iree_module_to_vmfb(
mlir_module,
device,
".",
self.mlir_dialect,
extra_args=self.extra_args,
compile_str=self.compile_str,
)
params = load_flatbuffer(
self.vmfb_file,
device,
mmap=True,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]
self.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
if self.vmfb_file == None:
self.vmfb_file = export_iree_module_to_vmfb(
mlir_module,
device,
".",
self.mlir_dialect,
extra_args=self.extra_args,
)
def setup_cl(self, input_tensors):
self.benchmark_cl = build_benchmark_args(
@@ -132,41 +111,42 @@ class SharkBenchmarkRunner(SharkRunner):
elif self.mlir_dialect in ["mhlo", "tf"]:
return self.benchmark_tf(modelname)
def benchmark_torch(self, modelname, device="cpu"):
def benchmark_torch(self, modelname):
import torch
from tank.model_utils import get_torch_model
# TODO: Pass this as an arg. currently the best way is to setup with BENCHMARK=1 if we want to use torch+cuda, else use cpu.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
torch.set_default_device("cuda:0")
# if self.enable_tf32:
# torch.backends.cuda.matmul.allow_tf32 = True
if self.device == "cuda":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
if self.enable_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
else:
torch.set_default_dtype(torch.float32)
torch.set_default_device("cpu")
torch_device = torch.device("cuda:0" if device == "cuda" else "cpu")
torch.set_default_tensor_type(torch.FloatTensor)
torch_device = torch.device(
"cuda:0" if self.device == "cuda" else "cpu"
)
HFmodel, input = get_torch_model(modelname, self.import_args)[:2]
frontend_model = HFmodel.model
frontend_model.to(torch_device)
if device == "cuda":
frontend_model.cuda()
input.to(torch.device("cuda:0"))
print(input)
else:
frontend_model.cpu()
input.cpu()
input.to(torch_device)
# TODO: re-enable as soon as pytorch CUDA context issues are resolved
try:
frontend_model = torch.compile(
frontend_model, mode="max-autotune", backend="inductor"
)
except RuntimeError:
frontend_model = HFmodel.model
for i in range(shark_args.num_warmup_iterations):
frontend_model.forward(input)
if device == "cuda":
if self.device == "cuda":
torch.cuda.reset_peak_memory_stats()
begin = time.time()
for i in range(shark_args.num_iterations):
out = frontend_model.forward(input)
end = time.time()
if device == "cuda":
if self.device == "cuda":
stats = torch.cuda.memory_stats()
device_peak_b = stats["allocated_bytes.all.peak"]
frontend_model.to(torch.device("cpu"))
@@ -178,7 +158,7 @@ class SharkBenchmarkRunner(SharkRunner):
print(
f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
)
if device == "cuda":
if self.device == "cuda":
# Set device to CPU so we don't run into segfaults exiting pytest subprocesses.
torch_device = torch.device("cpu")
return [

View File

@@ -1,7 +1,7 @@
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
import torch
import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -11,8 +11,14 @@ from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
# fmt: off
def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
def brevitasmatmul_rhs_group_quant〡shape(
lhs: List[int],
rhs: List[int],
rhs_scale: List[int],
rhs_zero_point: List[int],
rhs_bit_width: int,
rhs_group_size: int,
) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
@@ -21,21 +27,30 @@ def quantmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_s
raise ValueError("Input shapes not supported.")
def quantmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
def brevitasmatmul_rhs_group_quant〡dtype(
lhs_rank_dtype: Tuple[int, int],
rhs_rank_dtype: Tuple[int, int],
rhs_scale_rank_dtype: Tuple[int, int],
rhs_zero_point_rank_dtype: Tuple[int, int],
rhs_bit_width: int,
rhs_group_size: int,
) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype
def quantmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
def brevitasmatmul_rhs_group_quant〡has_value_semantics(
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
return
brevitas_matmul_rhs_group_quant_library = [
quantmatmul_rhs_group_quant〡shape,
quantmatmul_rhs_group_quant〡dtype,
quantmatmul_rhs_group_quant〡has_value_semantics]
# fmt: on
brevitasmatmul_rhs_group_quant〡shape,
brevitasmatmul_rhs_group_quant〡dtype,
brevitasmatmul_rhs_group_quant〡has_value_semantics,
]
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
@@ -107,7 +122,7 @@ def compile_int_precision(
torchscript_module,
inputs,
output_type="torch",
backend_legal_ops=["quant.matmul_rhs_group_quant"],
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
@@ -115,7 +130,7 @@ def compile_int_precision(
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
mlir_module,
"builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)",
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
from contextlib import redirect_stdout
@@ -130,17 +145,10 @@ def compile_int_precision(
mlir_module = mlir_module.encode("UTF-8")
mlir_module = BytesIO(mlir_module)
bytecode = mlir_module.read()
bytecode_path = os.path.join(
os.getcwd(), f"{extended_model_name}_linalg.mlirbc"
)
with open(bytecode_path, "wb") as f:
f.write(bytecode)
del bytecode
del mlir_module
print(f"Elided IR written for {extended_model_name}")
return bytecode_path
return bytecode
shark_module = SharkInference(
mlir_module=bytecode_path, device=device, mlir_dialect="tm_tensor"
mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
)
extra_args = [
"--iree-hal-dump-executable-sources-to=ies",
@@ -155,7 +163,7 @@ def compile_int_precision(
generate_vmfb=generate_vmfb,
extra_args=extra_args,
),
bytecode_path,
bytecode,
)
@@ -208,7 +216,7 @@ def shark_compile_through_fx(
]
else:
(
bytecode,
mlir_module,
_,
) = import_with_fx(
model=model,
@@ -219,11 +227,6 @@ def shark_compile_through_fx(
model_name=extended_model_name,
save_dir=save_dir,
)
mlir_module = save_mlir(
mlir_module=bytecode,
model_name=extended_model_name,
mlir_dialect=mlir_dialect,
)
shark_module = SharkInference(
mlir_module,

View File

@@ -111,20 +111,22 @@ os.makedirs(WORKDIR, exist_ok=True)
def check_dir_exists(model_name, frontend="torch", dynamic=""):
model_dir = os.path.join(WORKDIR, model_name)
# Remove the _tf keyword from end only for non-SD models.
if not any(model in model_name for model in ["clip", "unet", "vae"]):
if frontend in ["tf", "tensorflow"]:
model_name = model_name[:-3]
elif frontend in ["tflite"]:
model_name = model_name[:-7]
elif frontend in ["torch", "pytorch"]:
model_name = model_name[:-6]
model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir"
# Remove the _tf keyword from end.
if frontend in ["tf", "tensorflow"]:
model_name = model_name[:-3]
elif frontend in ["tflite"]:
model_name = model_name[:-7]
elif frontend in ["torch", "pytorch"]:
model_name = model_name[:-6]
if os.path.isdir(model_dir):
if (
os.path.isfile(os.path.join(model_dir, model_mlir_file_name))
os.path.isfile(
os.path.join(
model_dir,
model_name + dynamic + "_" + str(frontend) + ".mlir",
)
)
and os.path.isfile(os.path.join(model_dir, "function_name.npy"))
and os.path.isfile(os.path.join(model_dir, "inputs.npz"))
and os.path.isfile(os.path.join(model_dir, "golden_out.npz"))
@@ -275,11 +277,11 @@ def download_model(
model_dir = os.path.join(WORKDIR, model_dir_name)
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
mlir_filename = os.path.join(model_dir, model_name + suffix)
filename = os.path.join(model_dir, model_name + suffix)
print(
f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
f"Verifying that model artifacts were downloaded successfully to {filename}..."
)
if not os.path.exists(mlir_filename):
if not os.path.exists(filename):
from tank.generate_sharktank import gen_shark_files
print(
@@ -287,11 +289,13 @@ def download_model(
)
gen_shark_files(model_name, frontend, WORKDIR, import_args)
assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
assert os.path.exists(filename), f"MLIR not found at {filename}"
with open(filename, mode="rb") as f:
mlir_file = f.read()
function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
inputs = np.load(os.path.join(model_dir, "inputs.npz"))
golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))
inputs_tuple = tuple([inputs[key] for key in inputs])
golden_out_tuple = tuple([golden_out[key] for key in golden_out])
return mlir_filename, function_name, inputs_tuple, golden_out_tuple
return mlir_file, function_name, inputs_tuple, golden_out_tuple

View File

@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Tuple
from collections import defaultdict
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
import torchvision.models as models
import copy
import io
@@ -20,16 +20,10 @@ def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
bytecode_path = save_mlir(
bytecode,
model_name="shark_eager_module",
frontend="torch",
mlir_dialect="tm_tensor",
)
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode_path,
mlir_module=bytecode,
device=device,
mlir_dialect="tm_tensor",
)

View File

@@ -1,10 +1,8 @@
import re
import json
import numpy as np
import torch_mlir
from iree.compiler import compile_file
from shark.shark_importer import import_with_fx, get_f16_inputs, save_mlir
from iree.compiler import compile_str
from shark.shark_importer import import_with_fx, get_f16_inputs
class GenerateConfigFile:
@@ -13,7 +11,6 @@ class GenerateConfigFile:
model,
num_sharding_stages: int,
sharding_stages_id: list[str],
units_in_each_stage: list[int],
model_input=None,
config_file_path="model_config.json",
):
@@ -25,16 +22,13 @@ class GenerateConfigFile:
), "Number of sharding stages should be equal to the list of their ID"
self.model_input = model_input
self.config_file_path = config_file_path
# (Nithin) this is a quick fix - revisit and rewrite
self.units_in_each_stage = np.array(units_in_each_stage)
self.track_loop = np.zeros(len(self.sharding_stages_id)).astype(int)
def split_into_dispatches(
self,
backend,
fx_tracing_required=False,
fx_tracing_required=True,
f16_model=False,
torch_mlir_tracing=True,
torch_mlir_tracing=False,
):
graph_for_compilation = self.model
if fx_tracing_required:
@@ -54,15 +48,9 @@ class GenerateConfigFile:
verbose=False,
)
module = module.operation.get_asm(large_elements_limit=4)
module_file = save_mlir(
module,
model_name="module_pre_split",
frontend="torch",
mlir_dialect="linalg",
)
compiled_module_str = str(
compile_file(
module_file,
compile_str(
str(module),
target_backends=[backend],
extra_args=[
"--compile-to=flow",
@@ -107,17 +95,7 @@ class GenerateConfigFile:
if substring_before_final_period in model_dictionary:
del model_dictionary[substring_before_final_period]
# layer_dict = {n: "None" for n in self.sharding_stages_id}
# By default embed increasing device id's for each layer
increasing_wraparound_idx_list = (
self.track_loop % self.units_in_each_stage
)
layer_dict = {
n: int(increasing_wraparound_idx_list[idx][0][0])
for idx, n in enumerate(self.sharding_stages_id)
}
self.track_loop += 1
layer_dict = {n: "None" for n in self.sharding_stages_id}
model_dictionary[name] = layer_dict
self.generate_json(model_dictionary)
@@ -125,29 +103,3 @@ class GenerateConfigFile:
def generate_json(self, artifacts):
with open(self.config_file_path, "w") as outfile:
json.dump(artifacts, outfile)
if __name__ == "__main__":
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
FirstVicuna,
SecondVicuna7B,
CombinedModel,
)
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()

View File

@@ -451,108 +451,6 @@ def transform_fx(fx_g, quantized=False):
fx_g.graph.lint()
def gptq_transforms(fx_g):
import torch
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.ones,
torch.ops.aten._to_copy,
]:
if node.kwargs.get("device") == torch.device(device="cuda:0"):
updated_kwargs = node.kwargs.copy()
updated_kwargs["device"] = torch.device(device="cpu")
node.kwargs = updated_kwargs
if node.target in [
torch.ops.aten._to_copy,
]:
if node.kwargs.get("dtype") == torch.bfloat16:
updated_kwargs = node.kwargs.copy()
updated_kwargs["dtype"] = torch.float16
node.kwargs = updated_kwargs
# Inputs of aten.native_layer_norm should be upcasted to fp32.
if node.target in [torch.ops.aten.native_layer_norm]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (
new_node_arg0,
node.args[1],
node.args[2],
node.args[3],
node.args[4],
)
# Inputs of aten.mm should be upcasted to fp32.
if node.target in [torch.ops.aten.mm]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
new_node_arg1 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[1], torch.float32),
kwargs={},
)
node.args = (new_node_arg0, new_node_arg1)
# Outputs of aten.mm should be downcasted to fp16.
if type(node.args[0]) == torch.fx.node.Node and node.args[
0
].target in [torch.ops.aten.mm]:
with fx_g.graph.inserting_before(node):
tmp = node.args[0]
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node.args[0],),
kwargs={"dtype": torch.float16},
)
node.args[0].append(new_node)
node.args[0].replace_all_uses_with(new_node)
new_node.args = (tmp,)
new_node.kwargs = {"dtype": torch.float16}
# Inputs of aten._softmax should be upcasted to fp32.
if node.target in [torch.ops.aten._softmax]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (new_node_arg0, node.args[1], node.args[2])
# Outputs of aten._softmax should be downcasted to fp16.
if (
type(node.args[0]) == torch.fx.node.Node
and node.args[0].target in [torch.ops.aten._softmax]
and node.target in [torch.ops.aten.expand]
):
with fx_g.graph.inserting_before(node):
tmp = node.args[0]
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node.args[0],),
kwargs={"dtype": torch.float16},
)
node.args[0].append(new_node)
node.args[0].replace_all_uses_with(new_node)
new_node.args = (tmp,)
new_node.kwargs = {"dtype": torch.float16}
fx_g.graph.lint()
# Doesn't replace the None type.
def change_fx_graph_return_to_tuple(fx_g):
for node in fx_g.graph.nodes:
@@ -606,12 +504,27 @@ def import_with_fx(
is_dynamic=False,
tracing_required=False,
precision="fp32",
is_gptq=False,
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from typing import List
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
golden_values = None
if debug:
@@ -683,30 +596,8 @@ def import_with_fx(
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops.aten.index_add,
torch.ops.aten.index_add_,
]
if precision in ["int4", "int8"] and not is_gptq:
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
from brevitas_examples.llm.llm_quant.export import (
brevitas_layer_export_mode,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
LinearWeightBlockQuantHandlerFwd,
)
from brevitas_examples.llm.llm_quant.export import (
replace_call_fn_target,
)
from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
matmul_rhs_group_quant_placeholder,
)
from brevitas.backport.fx.experimental.proxy_tensor import (
make_fx as brevitas_make_fx,
)
if precision in ["int4", "int8"]:
export_context_manager = brevitas_layer_export_mode
export_class = block_quant_layer_level_manager(
export_handlers=[LinearWeightBlockQuantHandlerFwd]
@@ -721,7 +612,7 @@ def import_with_fx(
replace_call_fn_target(
fx_g,
src=matmul_rhs_group_quant_placeholder,
target=torch.ops.quant.matmul_rhs_group_quant,
target=torch.ops.brevitas.matmul_rhs_group_quant,
)
fx_g.recompile()
@@ -756,10 +647,6 @@ def import_with_fx(
add_upcast(fx_g)
fx_g.recompile()
if is_gptq:
gptq_transforms(fx_g)
fx_g.recompile()
if mlir_type == "fx":
return fx_g
@@ -790,27 +677,5 @@ def import_with_fx(
)
return mlir_module, func_name
mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
mlir_module, func_name = mlir_importer.import_mlir()
return mlir_module, func_name
# Saves a .mlir module python object to the directory 'dir' with 'model_name' and returns a path to the saved file.
def save_mlir(
mlir_module,
model_name,
mlir_dialect="linalg",
frontend="torch",
dir=tempfile.gettempdir(),
):
model_name_mlir = (
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = tempfile.gettempdir()
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if frontend == "torch":
with open(mlir_path, "wb") as mlir_file:
mlir_file.write(mlir_module)
return mlir_path

View File

@@ -39,7 +39,7 @@ class SharkInference:
Attributes
----------
mlir_module : str
mlir_module or path represented in string; modules from torch-mlir are serialized in bytecode format.
mlir_module represented in string; modules from torch-mlir are serialized in bytecode format.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -65,7 +65,7 @@ class SharkInference:
def __init__(
self,
mlir_module,
mlir_module: bytes,
device: str = "none",
mlir_dialect: str = "linalg",
is_benchmark: bool = False,
@@ -75,14 +75,6 @@ class SharkInference:
mmap: bool = True,
):
self.mlir_module = mlir_module
if mlir_module is not None:
if mlir_module and not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkInference with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.is_benchmark = is_benchmark
@@ -149,10 +141,6 @@ class SharkInference:
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
return self.shark_runner.run(function_name, inputs, send_to_host)
# forward function.
def forward(self, inputs: tuple, send_to_host=True):
return self.shark_runner.run("forward", inputs, send_to_host)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):
return self.shark_runner.get_functions_in_module()
@@ -200,9 +188,7 @@ class SharkInference:
# TODO: Instead of passing directory and having names decided by the module
# , user may want to save the module with manual names.
def save_module(
self, dir=os.getcwd(), module_name=None, extra_args=[], debug=False
):
def save_module(self, dir=os.getcwd(), module_name=None, extra_args=[]):
return export_iree_module_to_vmfb(
self.mlir_module,
self.device,
@@ -210,8 +196,6 @@ class SharkInference:
self.mlir_dialect,
module_name=module_name,
extra_args=extra_args,
debug=debug,
compile_str=self.compile_str,
)
# load and return the module.

View File

@@ -45,7 +45,7 @@ class SharkRunner:
Attributes
----------
mlir_module : str
mlir_module path, string, or bytecode.
mlir_module represented in string.
device : str
device to execute the mlir_module on.
currently supports cpu, cuda, vulkan, and metal backends.
@@ -74,14 +74,6 @@ class SharkRunner:
device_idx: int = None,
):
self.mlir_module = mlir_module
if self.mlir_module is not None:
if not os.path.isfile(mlir_module):
print(
"Warning: Initializing SharkRunner with a mlir string/bytecode object will duplicate the model in RAM at compile time. To avoid this, initialize SharkInference with a path to a MLIR module on your hard disk instead."
)
self.compile_str = True
else:
self.compile_str = False
self.device = shark_args.device if device == "none" else device
self.mlir_dialect = mlir_dialect
self.extra_args = extra_args
@@ -99,7 +91,6 @@ class SharkRunner:
self.mlir_dialect,
extra_args=self.extra_args,
device_idx=self.device_idx,
compile_str=self.compile_str,
)
self.iree_compilation_module = params["vmfb"]
self.iree_config = params["config"]

View File

@@ -15,7 +15,7 @@
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
from shark.backward_makefx import MakeFxModule
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
import numpy as np
from tqdm import tqdm
import sys
@@ -69,7 +69,7 @@ class SharkTrainer:
self.frontend = frontend
# Training function is needed in the case of torch_fn.
def compile(self, training_fn=None, mlir_type="linalg", extra_args=[]):
def compile(self, training_fn=None, extra_args=[]):
if self.frontend in ["torch", "pytorch"]:
packed_inputs = (
dict(self.model.named_parameters()),
@@ -77,18 +77,7 @@ class SharkTrainer:
tuple(self.input),
)
mlir_module, func_name = import_with_fx(
training_fn,
packed_inputs,
False,
[],
training=True,
mlir_type=mlir_type,
)
mlir_module = save_mlir(
mlir_module,
model_name="shark_model",
frontend="torch",
mlir_dialect=mlir_type,
training_fn, packed_inputs, False, [], training=True
)
self.shark_runner = SharkRunner(
mlir_module,

View File

@@ -1,6 +1,25 @@
resnet50,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
albert-base-v2,stablehlo,tf,1e-2,1e-2,default,None,False,False,False,"",""
roberta-base,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
bert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","enabled_windows"
camembert-base,stablehlo,tf,1e-2,1e-3,default,None,True,True,True,"",""
dbmdz/convbert-base-turkish-cased,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,True,True,False,"https://github.com/iree-org/iree/issues/9971",""
distilbert-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
facebook/convnext-tiny-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,True,True,False,"https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342","macos"
funnel-transformer/small,stablehlo,tf,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/201",""
google/electra-small-discriminator,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"Fails during iree-compile","macos"
google/vit-base-patch16-224,stablehlo,tf,1e-2,1e-3,tf_vit,nhcw-nhwc,False,False,False,"",""
microsoft/MiniLM-L12-H384-uncased,stablehlo,tf,1e-2,1e-3,tf_hf,None,True,False,False,"Fails during iree-compile.",""
microsoft/layoutlm-base-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
microsoft/mpnet-base,stablehlo,tf,1e-2,1e-2,default,None,True,True,True,"",""
albert-base-v2,linalg,torch,1e-2,1e-3,default,None,True,True,True,"issue with aten.tanh in torch-mlir",""
alexnet,linalg,torch,1e-2,1e-3,default,None,True,True,False,"https://github.com/nod-ai/SHARK/issues/879",""
bert-base-cased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-large-uncased,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
@@ -11,11 +30,18 @@ nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,False,True,"",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
efficientnet-v2-s,stablehlo,tf,1e-02,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b0,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"",""
efficientnet_b7,stablehlo,tf,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails on MacOS builder, VK device lost","macos"
gpt2,stablehlo,tf,1e-2,1e-3,default,None,True,False,False,"","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
t5-base,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
t5-large,stablehlo,tf,1e-2,1e-3,default,None,False,False,False,"","macos"
stabilityai/stable-diffusion-2-1-base,linalg,torch,1e-3,1e-3,default,None,True,False,False,"","macos"
1 bert-base-uncased resnet50 linalg stablehlo torch tf 1e-2 1e-3 default None nhcw-nhwc False True False False macos
1 resnet50 stablehlo tf 1e-2 1e-3 default nhcw-nhwc False False False macos
2 albert-base-v2 stablehlo tf 1e-2 1e-2 default None False False False
3 roberta-base stablehlo tf 1e-02 1e-3 default nhcw-nhwc True True True macos
4 bert-base-uncased stablehlo tf 1e-2 1e-3 default None False False False enabled_windows
5 camembert-base stablehlo tf 1e-2 1e-3 default None True True True
6 dbmdz/convbert-base-turkish-cased stablehlo tf 1e-2 1e-3 default nhcw-nhwc True True False https://github.com/iree-org/iree/issues/9971
7 distilbert-base-uncased stablehlo tf 1e-2 1e-3 default None False False False
8 facebook/convnext-tiny-224 stablehlo tf 1e-2 1e-3 tf_vit nhcw-nhwc True True False https://github.com/nod-ai/SHARK/issues/311 & https://github.com/nod-ai/SHARK/issues/342 macos
9 funnel-transformer/small stablehlo tf 1e-2 1e-3 default None True True False https://github.com/nod-ai/SHARK/issues/201
10 google/electra-small-discriminator stablehlo tf 1e-2 1e-3 default None False False False
11 google/mobilebert-uncased stablehlo tf 1e-2 1e-3 default None True False False Fails during iree-compile macos
12 google/vit-base-patch16-224 stablehlo tf 1e-2 1e-3 tf_vit nhcw-nhwc False False False
13 microsoft/MiniLM-L12-H384-uncased stablehlo tf 1e-2 1e-3 tf_hf None True False False Fails during iree-compile.
14 microsoft/layoutlm-base-uncased stablehlo tf 1e-2 1e-3 default None False False False
15 microsoft/mpnet-base stablehlo tf 1e-2 1e-2 default None True True True
16 albert-base-v2 linalg torch 1e-2 1e-3 default None True True True issue with aten.tanh in torch-mlir
17 alexnet linalg torch 1e-2 1e-3 default None True True False https://github.com/nod-ai/SHARK/issues/879
18 bert-base-cased linalg torch 1e-2 1e-3 default None False True False
19 bert-base-uncased bert-base-uncased linalg linalg torch torch 1e-2 1e-3 default None None False True False True False
20 bert-base-uncased_fp16 bert-base-uncased_fp16 linalg linalg torch torch 1e-1 1e-1 default None None True True True True
21 bert-large-uncased bert-large-uncased linalg linalg torch torch 1e-2 1e-3 default None None False True False True False
22 bert-large-uncased stablehlo tf 1e-2 1e-3 default None False False False
23 facebook/deit-small-distilled-patch16-224 facebook/deit-small-distilled-patch16-224 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False True False True False Fails during iree-compile.
24 google/vit-base-patch16-224 google/vit-base-patch16-224 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False True False True False https://github.com/nod-ai/SHARK/issues/311
25 microsoft/beit-base-patch16-224-pt22k-ft22k microsoft/beit-base-patch16-224-pt22k-ft22k linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False True False True False https://github.com/nod-ai/SHARK/issues/390 macos macos
30 resnet101 resnet101 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc/img2col nhcw-nhwc/img2col True False False False macos macos
31 resnet18 resnet18 linalg linalg torch torch 1e-2 1e-3 default None None True True False True False macos macos
32 resnet50 resnet50 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False False False False macos macos
33 resnet50_fp16 resnet50_fp16 linalg linalg torch torch 1e-2 1e-2 default nhcw-nhwc/img2col nhcw-nhwc/img2col True True True False True Numerics issues, awaiting cuda-independent fp16 integration
34 squeezenet1_0 squeezenet1_0 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc False False False False macos macos
35 wide_resnet50_2 wide_resnet50_2 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc/img2col nhcw-nhwc/img2col True False False False macos macos
36 efficientnet-v2-s stablehlo tf 1e-02 1e-3 default nhcw-nhwc False False False macos
37 mnasnet1_0 mnasnet1_0 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc True True True True macos macos
38 efficientnet_b0 efficientnet_b0 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc True True True True https://github.com/nod-ai/SHARK/issues/1487 macos macos
39 efficientnet_b7 efficientnet_b7 linalg linalg torch torch 1e-2 1e-3 default nhcw-nhwc nhcw-nhwc True True True True https://github.com/nod-ai/SHARK/issues/1487 macos macos
40 efficientnet_b0 stablehlo tf 1e-2 1e-3 default nhcw-nhwc False False False
41 efficientnet_b7 stablehlo tf 1e-2 1e-3 default nhcw-nhwc False False False Fails on MacOS builder, VK device lost macos
42 gpt2 stablehlo tf 1e-2 1e-3 default None True False False macos
43 t5-base t5-base linalg linalg torch torch 1e-2 1e-3 default None None True True True True Inputs for seq2seq models in torch currently unsupported. macos macos
44 t5-base stablehlo tf 1e-2 1e-3 default None False False False macos
45 t5-large t5-large linalg linalg torch torch 1e-2 1e-3 default None None True True True True Inputs for seq2seq models in torch currently unsupported macos macos
46 t5-large stablehlo tf 1e-2 1e-3 default None False False False macos
47 stabilityai/stable-diffusion-2-1-base linalg torch 1e-3 1e-3 default None True False False macos

View File

@@ -85,6 +85,8 @@ if __name__ == "__main__":
args = [
"--iree-llvmcpu-target-cpu-features=host",
"--iree-mhlo-demote-i64-to-i32=false",
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
]
backend_config = "dylib"
# backend = "cuda"

View File

@@ -1,26 +1,3 @@
# Run OPT for sentence completion through SHARK
# Running Different OPT Variants
From base SHARK directory, follow instructions to set up a virtual environment with SHARK. (`./setup_venv.sh` or `./setup_venv.ps1`)
Then, you may run opt_causallm.py to get a very simple sentence completion application running through SHARK
```
python opt_causallm.py
```
# Run OPT performance comparison on SHARK vs. PyTorch
```
python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
--platform=shark
```
Any OPT model from huggingface should work with this script, and you can choose between `--platform=shark` or `--platform=huggingface` to generate benchmarks of OPT inference on SHARK / PyTorch.
# Run a small suite of OPT models through the benchmark script
```
python opt_perf_comparison_batch.py
```
This script will run benchmarks from a suite of OPT configurations:
- Sequence Lengths: 32, 128, 256, 512
- Parameter Counts: 125m, 350m, 1.3b
note: Most of these scripts are written for use on CPU, as perf comparisons against pytorch can be problematic across platforms otherwise.
To run different sizes of OPT, change the string `OPT_MODEL` string in `opt_torch_test.py`. The default is 350m parameters. 66b cases also exist in the file, simply uncomment the test cases.

View File

@@ -36,7 +36,9 @@ def create_module(model_name, tokenizer, device):
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
if os.path.isfile(mlir_path):
print(f"Found .mlir from {mlir_path}")
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
else:
(model_mlir, func_name) = import_with_fx(
model=opt_model,
@@ -48,17 +50,16 @@ def create_module(model_name, tokenizer, device):
with open(mlir_path, "w") as f:
f.write(model_mlir)
print(f"Saved mlir at {mlir_path}")
del model_mlir
shark_module = SharkInference(
mlir_path,
model_mlir,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=False,
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
shark_module.save_module(module_name=vmfb_name, debug=False)
shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path

View File

@@ -6,7 +6,7 @@ import numpy as np
from shark_opt_wrapper import OPTForCausalLMModel
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "facebook/opt-1.3b"
@@ -57,10 +57,9 @@ class OPTModuleTester:
with open(mlir_path, "w") as f:
f.write(mlir_module)
print(f"Saved mlir at {mlir_path}")
del mlir_module
shark_module = SharkInference(
mlir_path,
mlir_module,
device=device,
mlir_dialect="tm_tensor",
is_benchmark=self.benchmark,

View File

@@ -1,45 +1,18 @@
"""
Script for comparing OPT model performance between SHARK and Huggingface
PyTorch.
Usage Example:
python opt_perf_comparison.py --max-seq-len=32 --model-name=facebook/opt-125m \
--platform=shark
python opt_perf_comparison.py --max-seq-len=512 --model-name=facebook/opt-1.3b \
--platform=shark
See parse_args() below for command line argument usage.
"""
import argparse
import collections
import json
import os
import psutil
import time
from typing import Tuple
import os
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
from shark_opt_wrapper import OPTForCausalLMModel
MODEL_NAME = "facebook/opt-1.3b"
OPT_MODELNAME = "opt-1.3b"
OPT_FS_NAME = "opt_1-3b"
MAX_SEQUENCE_LENGTH = 512
DEVICE = "cpu"
PLATFORM_SHARK = "shark"
PLATFORM_HUGGINGFACE = "huggingface"
# Dict keys for reports.
REPORT_PLATFORM = "platform"
REPORT_MODEL_NAME = "model"
REPORT_MAX_SEQ_LEN = "max_seq_len"
REPORT_LOAD_TIME = "load_time_sec"
REPORT_RUN_TIME = "run_time_sec"
REPORT_LOAD_PHYSICAL_MEMORY_MB = "load_physical_MB"
REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
PROMPTS = [
"What is the meaning of life?",
@@ -57,27 +30,15 @@ PROMPTS = [
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
def get_memory_info():
pid = os.getpid()
process = psutil.Process(pid)
return process.memory_info()
def create_vmfb_module(
model_name: str,
tokenizer,
device: str,
max_seq_len: int,
recompile_shark: bool,
):
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
def create_vmfb_module(model_name, tokenizer, device):
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
PROMPTS[0],
"What is the meaning of life?",
padding="max_length",
truncation=True,
max_length=max_seq_len,
max_length=MAX_SEQUENCE_LENGTH,
return_tensors="pt",
)
inputs = (
@@ -87,16 +48,8 @@ def create_vmfb_module(
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
opt_fs_name = get_opt_fs_name(model_name)
mlir_path = f"./{opt_fs_name}_causallm_{max_seq_len}_torch.mlir"
# If MLIR has already been loaded and recompilation is not requested, use
# the loaded MLIR file.
has_mlir = os.path.isfile(mlir_path)
# The purpose of recompile_shark is to measure compilation time; the
# compilation time can be correctly measured only when MLIR has already been
# loaded.
assert not recompile_shark or has_mlir
if has_mlir:
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
if os.path.isfile(mlir_path):
with open(mlir_path, "r") as f:
model_mlir = f.read()
print(f"Loaded .mlir from {mlir_path}")
@@ -105,7 +58,7 @@ def create_vmfb_module(
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=opt_fs_name,
model_name=OPT_FS_NAME,
return_str=True,
)
with open(mlir_path, "w") as f:
@@ -119,25 +72,18 @@ def create_vmfb_module(
is_benchmark=False,
)
vmfb_name = (
f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels"
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{DEVICE}_tiled_ukernels"
shark_module.save_module(module_name=vmfb_name)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
def load_shark_model(
model_name: str, max_seq_len: int, recompile_shark: bool
) -> ModelWrapper:
opt_fs_name = get_opt_fs_name(model_name)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
if recompile_shark or not os.path.isfile(vmfb_name):
def load_shark_model() -> ModelWrapper:
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{DEVICE}_tiled_ukernels.vmfb"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
if not os.path.isfile(vmfb_name):
print(f"vmfb not found. compiling and saving to {vmfb_name}")
create_vmfb_module(
model_name, tokenizer, DEVICE, max_seq_len, recompile_shark
)
create_vmfb_module(OPT_MODELNAME, tokenizer, DEVICE)
shark_module = SharkInference(mlir_module=None, device="cpu-task")
shark_module.load_module(vmfb_name)
return ModelWrapper(model=shark_module, tokenizer=tokenizer)
@@ -148,10 +94,20 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
return model_wrapper.model("forward", tokens)
def load_huggingface_model(model_name: str) -> ModelWrapper:
def run_shark():
model_wrapper = load_shark_model()
prompt = "What is the meaning of life?"
logits = run_shark_model(model_wrapper, prompt)
# Print output logits to validate vs. pytorch + base transformers
print(logits[0])
def load_huggingface_model() -> ModelWrapper:
return ModelWrapper(
model=OPTForCausalLM.from_pretrained(model_name),
tokenizer=AutoTokenizer.from_pretrained(model_name),
model=OPTForCausalLM.from_pretrained(MODEL_NAME),
tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
)
@@ -161,71 +117,47 @@ def run_huggingface_model(model_wrapper: ModelWrapper, tokens):
)
def run_huggingface():
model_wrapper = load_huggingface_model()
prompt = "What is the meaning of life?"
logits = run_huggingface_model(model_wrapper, prompt)
print(logits[0])
def save_json(data, filename):
with open(filename, "w") as file:
json.dump(data, file)
def collect_huggingface_logits(
model_name: str, max_seq_len: int, to_save_json: bool
) -> Tuple[float, float]:
# Load
def collect_huggingface_logits():
t0 = time.time()
model_wrapper = load_huggingface_model(model_name)
load_time = time.time() - t0
print("--- Took {} seconds to load Huggingface.".format(load_time))
load_memory_info = get_memory_info()
model_wrapper = load_huggingface_model()
print("--- Took {} seconds to load Huggingface.".format(time.time() - t0))
results = []
tokenized_prompts = []
for prompt in PROMPTS:
tokens = model_wrapper.tokenizer(
prompt,
padding="max_length",
max_length=max_seq_len,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
return_tensors="pt",
)
tokenized_prompts.append(tokens)
# Run
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_huggingface_model(model_wrapper, tokens)
if to_save_json:
results.append([PROMPTS[idx], logits[0].tolist()])
run_time = time.time() - t0
print("--- Took {} seconds to run Huggingface.".format(run_time))
if to_save_json:
save_json(results, "/tmp/huggingface.json")
run_memory_info = get_memory_info()
return {
REPORT_PLATFORM: PLATFORM_HUGGINGFACE,
REPORT_MODEL_NAME: model_name,
REPORT_MAX_SEQ_LEN: max_seq_len,
REPORT_LOAD_TIME: load_time,
REPORT_RUN_TIME: run_time / len(PROMPTS),
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
}
results.append([PROMPTS[idx], logits[0].tolist()])
print("--- Took {} seconds to run Huggingface.".format(time.time() - t0))
save_json(results, "/tmp/huggingface.json")
def collect_shark_logits(
model_name: str,
max_seq_len: int,
recompile_shark: bool,
to_save_json: bool,
) -> Tuple[float, float]:
# Load
def collect_shark_logits():
t0 = time.time()
model_wrapper = load_shark_model(model_name, max_seq_len, recompile_shark)
load_time = time.time() - t0
print("--- Took {} seconds to load Shark.".format(load_time))
load_memory_info = get_memory_info()
model_wrapper = load_shark_model()
print("--- Took {} seconds to load Shark.".format(time.time() - t0))
results = []
tokenized_prompts = []
for prompt in PROMPTS:
@@ -233,7 +165,7 @@ def collect_shark_logits(
prompt,
padding="max_length",
truncation=True,
max_length=max_seq_len,
max_length=MAX_SEQUENCE_LENGTH,
return_tensors="pt",
)
inputs = (
@@ -241,100 +173,16 @@ def collect_shark_logits(
tokens["attention_mask"],
)
tokenized_prompts.append(inputs)
# Run
t0 = time.time()
for idx, tokens in enumerate(tokenized_prompts):
print("prompt: {}".format(PROMPTS[idx]))
logits = run_shark_model(model_wrapper, tokens)
lst = [e.tolist() for e in logits]
if to_save_json:
results.append([PROMPTS[idx], lst])
run_time = time.time() - t0
print("--- Took {} seconds to run Shark.".format(run_time))
if to_save_json:
save_json(results, "/tmp/shark.json")
platform_postfix = "-compile" if recompile_shark else "-precompiled"
run_memory_info = get_memory_info()
return {
REPORT_PLATFORM: PLATFORM_SHARK + platform_postfix,
REPORT_MODEL_NAME: model_name,
REPORT_MAX_SEQ_LEN: max_seq_len,
REPORT_LOAD_TIME: load_time,
REPORT_RUN_TIME: run_time / len(PROMPTS),
REPORT_LOAD_PHYSICAL_MEMORY_MB: load_memory_info.rss >> 20,
REPORT_LOAD_VIRTUAL_MEMORY_MB: load_memory_info.vms >> 20,
REPORT_RUN_PHYSICAL_MEMORY_MB: run_memory_info.rss >> 20,
REPORT_RUN_VIRTUAL_MEMORY_MB: run_memory_info.vms >> 20,
}
def get_opt_fs_name(model_name: str) -> str:
"""Cleanses the model name ino a file system-friendly name.
Example: get_opt_fs_name('facebook/opt-1.3b') == 'opt_1-3b'
"""
slash_split = model_name.split("/")
assert 1 <= len(slash_split) <= 2, "There should be at most one slash."
model_name = slash_split[-1]
for src_pattern, dest_pattern in (("-", "_"), (".", "-")):
model_name = model_name.replace(src_pattern, dest_pattern)
return model_name
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--save-json",
help="If set, saves output JSON.",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--max-seq-len", help="Max sequence length", type=int, default=32
)
parser.add_argument(
"--model-name",
help="Model name",
type=str,
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
],
default="facebook/opt-1.3b",
)
parser.add_argument(
"--recompile-shark",
help="If set, recompiles MLIR",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--platform",
help="Either shark or huggingface",
type=str,
choices=[PLATFORM_SHARK, PLATFORM_HUGGINGFACE],
default=PLATFORM_SHARK,
)
args = parser.parse_args()
print("args={}".format(args))
return args
results.append([PROMPTS[idx], lst])
print("--- Took {} seconds to run Shark.".format(time.time() - t0))
save_json(results, "/tmp/shark.json")
if __name__ == "__main__":
args = parse_args()
if args.platform == PLATFORM_SHARK:
shark_report = collect_shark_logits(
args.model_name,
args.max_seq_len,
args.recompile_shark,
args.save_json,
)
print("# Summary: {}".format(json.dumps(shark_report)))
else:
huggingface_report = collect_huggingface_logits(
args.model_name, args.max_seq_len, args.save_json
)
print("# Summary: {}".format(json.dumps(huggingface_report)))
collect_shark_logits()
collect_huggingface_logits()

View File

@@ -1,30 +0,0 @@
"""
Script for running opt_perf_comparison.py in batch with a series of arguments.
Usage: python opt_perf_comparison_batch.py
"""
from typing import Iterable, List
import shlex
import subprocess
def make_commands() -> Iterable[List[str]]:
command = shlex.split("python opt_perf_comparison.py --no-save-json")
max_seq_lens = [32, 128, 256, 512]
model_names = ["facebook/opt-" + e for e in ["125m", "350m", "1.3b"]]
for max_seq_len in max_seq_lens:
for model_name in model_names:
yield command + [
f"--max-seq-len={max_seq_len}",
f"--model-name={model_name}",
]
def main():
for command in make_commands():
result = subprocess.run(command, check=True)
if __name__ == "__main__":
main()

View File

@@ -2,7 +2,7 @@ import os
import torch
from transformers import AutoTokenizer, OPTForCausalLM
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
from shark_opt_wrapper import OPTForCausalLMModel
model_name = "facebook/opt-1.3b"
@@ -25,13 +25,11 @@ inputs = (
model=model,
inputs=inputs,
is_f16=False,
)
mlir_module = save_mlir(
mlir_module,
debug=True,
model_name=model_name.split("/")[1],
frontend="torch",
mlir_dialect="linalg",
save_dir=".",
)
shark_module = SharkInference(
mlir_module,
device="cpu-sync",

View File

@@ -16,6 +16,12 @@ import subprocess as sp
import hashlib
import numpy as np
from pathlib import Path
from apps.stable_diffusion.src.models import (
model_wrappers as mw,
)
from apps.stable_diffusion.src.utils.stable_args import (
args,
)
def create_hash(file_name):
@@ -36,7 +42,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
get_hf_img_cls_model,
get_fp16_model,
)
from shark.shark_importer import import_with_fx, save_mlir
from shark.shark_importer import import_with_fx
with open(torch_model_list) as csvfile:
torch_reader = csv.reader(csvfile, delimiter=",")
@@ -54,6 +60,31 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
print("generating artifacts for: " + torch_model_name)
model = None
input = None
if model_type == "stable_diffusion":
args.use_tuned = False
args.import_mlir = True
args.local_tank_cache = local_tank_cache
precision_values = ["fp16"]
seq_lengths = [64, 77]
for precision_value in precision_values:
args.precision = precision_value
for length in seq_lengths:
model = mw.SharkifyStableDiffusionModel(
model_id=torch_model_name,
custom_weights="",
precision=precision_value,
max_len=length,
width=512,
height=512,
use_base_vae=False,
custom_vae="",
debug=True,
sharktank_dir=local_tank_cache,
generate_vmfb=False,
)
model()
continue
if model_type == "vision":
model, input, _ = get_vision_model(
torch_model_name, import_args
@@ -72,11 +103,10 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
model, input, _ = get_hf_img_cls_model(
torch_model_name, import_args
)
elif model_type == "fp16":
model, input, _ = get_fp16_model(torch_model_name, import_args)
torch_model_name = torch_model_name.replace("/", "_")
if import_args["batch_size"] > 1:
print(
f"Batch size for this model set to {import_args['batch_size']}"
)
if import_args["batch_size"] != 1:
torch_model_dir = os.path.join(
local_tank_cache,
str(torch_model_name)
@@ -130,6 +160,133 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
)
def save_tf_model(tf_model_list, local_tank_cache, import_args):
from tank.model_utils_tf import (
get_causal_image_model,
get_masked_lm_model,
get_causal_lm_model,
get_keras_model,
get_TFhf_model,
get_tfhf_seq2seq_model,
)
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
visible_default = tf.config.list_physical_devices("GPU")
try:
tf.config.set_visible_devices([], "GPU")
visible_devices = tf.config.get_visible_devices()
for device in visible_devices:
assert device.device_type != "GPU"
except:
# Invalid device or cannot modify virtual devices once initialized.
pass
with open(tf_model_list) as csvfile:
tf_reader = csv.reader(csvfile, delimiter=",")
fields = next(tf_reader)
for row in tf_reader:
tf_model_name = row[0]
model_type = row[1]
model = None
input = None
print(f"Generating artifacts for model {tf_model_name}")
if model_type == "hf":
model, input, _ = get_masked_lm_model(
tf_model_name, import_args
)
elif model_type == "img":
model, input, _ = get_causal_image_model(
tf_model_name, import_args
)
elif model_type == "keras":
model, input, _ = get_keras_model(tf_model_name, import_args)
elif model_type == "TFhf":
model, input, _ = get_TFhf_model(tf_model_name, import_args)
elif model_type == "tfhf_seq2seq":
model, input, _ = get_tfhf_seq2seq_model(
tf_model_name, import_args
)
elif model_type == "hf_causallm":
model, input, _ = get_causal_lm_model(
tf_model_name, import_args
)
tf_model_name = tf_model_name.replace("/", "_")
if import_args["batch_size"] != 1:
tf_model_dir = os.path.join(
local_tank_cache,
str(tf_model_name)
+ "_tf"
+ f"_BS{str(import_args['batch_size'])}",
)
else:
tf_model_dir = os.path.join(
local_tank_cache, str(tf_model_name) + "_tf"
)
os.makedirs(tf_model_dir, exist_ok=True)
mlir_importer = SharkImporter(
model,
inputs=input,
frontend="tf",
)
mlir_importer.import_debug(
is_dynamic=False,
dir=tf_model_dir,
model_name=tf_model_name,
)
def save_tflite_model(tflite_model_list, local_tank_cache, import_args):
from shark.tflite_utils import TFLitePreprocessor
with open(tflite_model_list) as csvfile:
tflite_reader = csv.reader(csvfile, delimiter=",")
for row in tflite_reader:
print("\n")
tflite_model_name = row[0]
tflite_model_link = row[1]
print("tflite_model_name", tflite_model_name)
print("tflite_model_link", tflite_model_link)
tflite_model_name_dir = os.path.join(
local_tank_cache, str(tflite_model_name) + "_tflite"
)
os.makedirs(tflite_model_name_dir, exist_ok=True)
print(f"TMP_TFLITE_MODELNAME_DIR = {tflite_model_name_dir}")
# Preprocess to get SharkImporter input import_args
tflite_preprocessor = TFLitePreprocessor(str(tflite_model_name))
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()
# Use SharkImporter to get SharkInference input import_args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
frontend="tflite",
raw_model_file=raw_model_file_path,
)
my_shark_importer.import_debug(
dir=tflite_model_name_dir,
model_name=tflite_model_name,
func_name="main",
)
mlir_hash = create_hash(
os.path.join(
tflite_model_name_dir,
tflite_model_name + "_tflite" + ".mlir",
)
)
np.save(
os.path.join(tflite_model_name_dir, "hash"),
np.array(mlir_hash),
)
def check_requirements(frontend):
import importlib
@@ -138,6 +295,10 @@ def check_requirements(frontend):
tv_spec = importlib.util.find_spec("torchvision")
has_pkgs = tv_spec is not None
elif frontend in ["tensorflow", "tf"]:
tf_spec = importlib.util.find_spec("tensorflow")
has_pkgs = tf_spec is not None
return has_pkgs
@@ -156,11 +317,27 @@ def gen_shark_files(modelname, frontend, tank_dir, importer_args):
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tf_model_list.csv"
)
custom_model_csv = tempfile.NamedTemporaryFile(
dir=os.path.dirname(__file__),
delete=True,
)
if frontend == "torch":
# Create a temporary .csv with only the desired entry.
if frontend == "tf":
with open(tf_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
if row[0] == modelname:
target = row
with open(custom_model_csv.name, mode="w") as trg:
writer = csv.writer(trg)
writer.writerow(["modelname", "src"])
writer.writerow(target)
save_tf_model(custom_model_csv.name, tank_dir, import_args)
elif frontend == "torch":
with open(torch_model_csv, mode="r") as src:
reader = csv.reader(src)
for row in reader:
@@ -194,6 +371,18 @@ if __name__ == "__main__":
# Please see: https://github.com/nod-ai/SHARK/blob/main/tank/torch_model_list.csv""",
# )
# parser.add_argument(
# "--tf_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tf_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--tflite_model_csv",
# type=lambda x: is_valid_file(x),
# default="./tank/tflite/tflite_model_list.csv",
# help="Contains the file with tf model name and args.",
# )
# parser.add_argument(
# "--ci_tank_dir",
# type=bool,
# default=False,
@@ -202,7 +391,7 @@ if __name__ == "__main__":
# old_import_args = parser.parse_import_args()
import_args = {
"batch_size": 1,
"batch_size": "1",
}
print(import_args)
home = str(Path.home())
@@ -210,5 +399,16 @@ if __name__ == "__main__":
torch_model_csv = os.path.join(
os.path.dirname(__file__), "torch_model_list.csv"
)
tf_model_csv = os.path.join(os.path.dirname(__file__), "tf_model_list.csv")
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tflite", "tflite_model_list.csv"
)
save_torch_model(
os.path.join(os.path.dirname(__file__), "torch_sd_list.csv"),
WORKDIR,
import_args,
)
save_torch_model(torch_model_csv, WORKDIR, import_args)
save_tf_model(tf_model_csv, WORKDIR, import_args)
save_tflite_model(tflite_model_csv, WORKDIR, import_args)

View File

@@ -278,7 +278,7 @@ def get_vision_model(torch_model, import_args):
int(import_args["batch_size"]), 3, *input_image_size
)
actual_out = model(test_input)
if fp16_model == True:
if fp16_model is not None:
test_input_fp16 = test_input.to(
device=torch.device("cuda"), dtype=torch.half
)

View File

@@ -145,7 +145,6 @@ class SharkModuleTester:
shark_args.shark_prefix = self.shark_tank_prefix
shark_args.local_tank_cache = self.local_tank_cache
shark_args.dispatch_benchmarks = self.benchmark_dispatches
shark_args.enable_tf32 = self.tf32
if self.benchmark_dispatches is not None:
_m = self.config["model_name"].split("/")
@@ -217,12 +216,10 @@ class SharkModuleTester:
result = shark_module(func_name, inputs)
golden_out, result = self.postprocess_outputs(golden_out, result)
if self.tf32 == True:
print(
"Validating with relaxed tolerances for TensorFloat32 calculations."
)
self.config["atol"] = 1e-01
self.config["rtol"] = 1e-02
if self.tf32 == "true":
print("Validating with relaxed tolerances.")
atol = 1e-02
rtol = 1e-03
try:
np.testing.assert_allclose(
golden_out,
@@ -257,6 +254,9 @@ class SharkModuleTester:
model_config = {
"batch_size": self.batch_size,
}
shark_args.enable_tf32 = self.tf32
if shark_args.enable_tf32 == True:
shark_module.compile()
shark_args.onnx_bench = self.onnx_bench
shark_module.shark_runner.benchmark_all_csv(
@@ -287,9 +287,6 @@ class SharkModuleTester:
repro_path = os.path.join("reproducers", self.tmp_prefix, "*")
bashCommand = f"gsutil cp -r {repro_path} gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
print(
f"Uploading reproducer {repro_path} to gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
)
process = subprocess.run(bashCommand.split())
def postprocess_outputs(self, golden_out, result):

28
tank/tf_model_list.csv Normal file
View File

@@ -0,0 +1,28 @@
model_name, model_type
albert-base-v2,hf
bert-base-uncased,hf
camembert-base,hf
dbmdz/convbert-base-turkish-cased,hf
distilbert-base-uncased,hf
google/electra-small-discriminator,hf
funnel-transformer/small,hf
microsoft/layoutlm-base-uncased,hf
google/mobilebert-uncased,hf
microsoft/mpnet-base,hf
roberta-base,hf
resnet50,keras
xlm-roberta-base,hf
microsoft/MiniLM-L12-H384-uncased,TFhf
funnel-transformer/small,hf
microsoft/mpnet-base,hf
facebook/convnext-tiny-224,img
google/vit-base-patch16-224,img
efficientnet-v2-s,keras
bert-large-uncased,hf
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
efficientnet_b0,keras
efficientnet_b7,keras
gpt2,hf_causallm
t5-base,tfhf_seq2seq
t5-large,tfhf_seq2seq
1 model_name model_type
2 albert-base-v2 hf
3 bert-base-uncased hf
4 camembert-base hf
5 dbmdz/convbert-base-turkish-cased hf
6 distilbert-base-uncased hf
7 google/electra-small-discriminator hf
8 funnel-transformer/small hf
9 microsoft/layoutlm-base-uncased hf
10 google/mobilebert-uncased hf
11 microsoft/mpnet-base hf
12 roberta-base hf
13 resnet50 keras
14 xlm-roberta-base hf
15 microsoft/MiniLM-L12-H384-uncased TFhf
16 funnel-transformer/small hf
17 microsoft/mpnet-base hf
18 facebook/convnext-tiny-224 img
19 google/vit-base-patch16-224 img
20 efficientnet-v2-s keras
21 bert-large-uncased hf
22 t5-base tfhf_seq2seq
23 t5-large tfhf_seq2seq
24 efficientnet_b0 keras
25 efficientnet_b7 keras
26 gpt2 hf_causallm
27 t5-base tfhf_seq2seq
28 t5-large tfhf_seq2seq

View File

@@ -5,6 +5,7 @@ microsoft/MiniLM-L12-H384-uncased,True,hf,True,linalg,False,66M,"nlp;bert-varian
bert-base-uncased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-base-cased,True,hf,True,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
google/mobilebert-uncased,True,hf,True,linalg,False,25M,"nlp,bert-variant,transformer-encoder,mobile","24 layers, 512 hidden size, 128 embedding"
alexnet,False,vision,True,linalg,False,61M,"cnn,parallel-layers","The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod."
resnet18,False,vision,True,linalg,False,11M,"cnn,image-classification,residuals,resnet-variant","1 7x7 conv2d and the rest are 3x3 conv2d"
resnet50,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
resnet101,False,vision,True,linalg,False,29M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
@@ -17,9 +18,11 @@ facebook/deit-small-distilled-patch16-224,True,hf_img_cls,False,linalg,False,22M
microsoft/beit-base-patch16-224-pt22k-ft22k,True,hf_img_cls,False,linalg,False,86M,"image-classification,transformer-encoder,bert-variant,vision-transformer",N/A
nvidia/mit-b0,True,hf_img_cls,False,linalg,False,3.7M,"image-classification,transformer-encoder",SegFormer
mnasnet1_0,False,vision,True,linalg,False,-,"cnn, torchvision, mobile, architecture-search","Outperforms other mobile CNNs on Accuracy vs. Latency"
resnet50_fp16,False,vision,True,linalg,False,23M,"cnn,image-classification,residuals,resnet-variant","Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)"
bert-base-uncased_fp16,True,fp16,False,linalg,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
bert-large-uncased,True,hf,True,linalg,False,330M,"nlp;bert-variant;transformer-encoder","24 layers, 1024 hidden units, 16 attention heads"
bert-base-uncased,True,hf,False,stablehlo,False,109M,"nlp;bert-variant;transformer-encoder","12 layers; 768 hidden; 12 attention heads"
gpt2,True,hf_causallm,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
facebook/opt-125m,True,hf,False,stablehlo,True,125M,"nlp;transformer-encoder","-"
distilgpt2,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
microsoft/deberta-v3-base,True,hf,False,stablehlo,True,88M,"nlp;transformer-encoder","-"
1 model_name use_tracing model_type dynamic mlir_type decompose param_count tags notes
5 bert-base-uncased True hf True linalg False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
6 bert-base-cased True hf True linalg False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
7 google/mobilebert-uncased True hf True linalg False 25M nlp,bert-variant,transformer-encoder,mobile 24 layers, 512 hidden size, 128 embedding
8 alexnet False vision True linalg False 61M cnn,parallel-layers The CNN that revolutionized computer vision (move away from hand-crafted features to neural networks),10 years old now and probably no longer used in prod.
9 resnet18 False vision True linalg False 11M cnn,image-classification,residuals,resnet-variant 1 7x7 conv2d and the rest are 3x3 conv2d
10 resnet50 False vision True linalg False 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
11 resnet101 False vision True linalg False 29M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
18 microsoft/beit-base-patch16-224-pt22k-ft22k True hf_img_cls False linalg False 86M image-classification,transformer-encoder,bert-variant,vision-transformer N/A
19 nvidia/mit-b0 True hf_img_cls False linalg False 3.7M image-classification,transformer-encoder SegFormer
20 mnasnet1_0 False vision True linalg False - cnn, torchvision, mobile, architecture-search Outperforms other mobile CNNs on Accuracy vs. Latency
21 resnet50_fp16 False vision True linalg False 23M cnn,image-classification,residuals,resnet-variant Bottlenecks with only conv2d (1x1 conv -> 3x3 conv -> 1x1 conv blocks)
22 bert-base-uncased_fp16 True fp16 False linalg False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
23 bert-large-uncased True hf True linalg False 330M nlp;bert-variant;transformer-encoder 24 layers, 1024 hidden units, 16 attention heads
24 bert-base-uncased True hf False stablehlo False 109M nlp;bert-variant;transformer-encoder 12 layers; 768 hidden; 12 attention heads
25 gpt2 True hf_causallm False stablehlo True 125M nlp;transformer-encoder -
26 facebook/opt-125m True hf False stablehlo True 125M nlp;transformer-encoder -
27 distilgpt2 True hf False stablehlo True 88M nlp;transformer-encoder -
28 microsoft/deberta-v3-base True hf False stablehlo True 88M nlp;transformer-encoder -