Compare commits


17 Commits

Author SHA1 Message Date
Vivek Khandelwal
322874f7f9 Fix issue in Falcon-GPTQ 2023-11-03 11:48:36 +05:30
Ean Garvey
5001db3415 Add 7800xt to target triples explicitly. (#1928) 2023-11-01 17:11:45 -05:00
Vivek Khandelwal
71846344a2 Add sharded Falcon-GPTQ support
This commit adds support for sharded Falcon-7b-GPTQ and
Falcon-180B-GPTQ. It also adds support for 4-way sharding of
the Falcon model on the ROCm device.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
2023-11-01 12:11:44 +05:30
gpetters94
72e27c96fc Add ZoeDepth (#1834)
* Add ZoeDepth

* Add einops to Studio imports.

* Specify ref for forked torch.hub repos.

* Unpin timm.

---------

Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Co-authored-by: Ean Garvey <garveyej@gmail.com>
2023-10-30 11:57:45 -05:00
PhaneeshB
7963abb8ec remove caching for rocm args 2023-10-29 07:07:57 +05:30
Ean Garvey
98244232dd Add smoothquant OPT to examples. (#1922) 2023-10-27 12:32:12 -05:00
PhaneeshB
679a452139 fix calls and remove unused imports for check_device_drivers 2023-10-27 10:30:40 +05:30
PhaneeshB
72c0a8abc8 remove dependency on external commands for driver installation check 2023-10-27 10:30:40 +05:30
Vivek Khandelwal
ea920f2955 Add sharded Falcon support 2023-10-26 21:53:25 +05:30
Phaneesh Barwaria
486202377a update dependency on rocm/hip info command (#1900)
* add support for rocm flags

* add rocm target flag to chat args

* rm rocm libs dependency message
2023-10-26 15:18:25 +05:30
Sungsoon Cho
0c38c33d0a Add opt_causallm_samples.py. (#1916) 2023-10-25 11:52:51 -05:00
Ean Garvey
841773fa32 Updates to opt_causallm example (#1905)
* Updates to opt_causallm example

* Fixup opt_perf_comparison.py

* Use same filenames across opt examples.
2023-10-24 10:54:39 -07:00
Stefan Kapusniak
0361db46f9 SD: Fix unet untuned opt_flags (#1912)
* Correct my sloppy copy/paste for the untuned unet default compilation
flags, which introduced an extra 'detach' into what should have been
'iree-global-opt-convert-1x1-filter-conv2d-to-matmul'
2023-10-24 12:47:33 -05:00
xzuyn
a012433ffd Save hiresfix info if used (#1914) 2023-10-24 12:45:10 -05:00
xzuyn
5061193da3 Move Generate, Randomize Seed, & Stop Batch to same positions as txt2img (#1915) 2023-10-24 12:44:39 -05:00
xzuyn
bff48924be LLaMa 2 Chat template fix (#1913) 2023-10-23 18:51:15 -05:00
Stefan Kapusniak
825b36cbdd Fix MLIR Textual PassPipeline Error (#1910) 2023-10-22 07:39:52 -07:00
31 changed files with 977 additions and 282 deletions

View File

@@ -70,7 +70,7 @@ class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
class DecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
@@ -85,9 +85,15 @@ class DecoderLayer(torch.nn.Module):
class CompiledDecoderLayer(torch.nn.Module):
def __init__(self, shark_decoder_layer_module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
):
super().__init__()
self.model = shark_decoder_layer_module
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
def forward(
self,
@@ -99,6 +105,26 @@ class CompiledDecoderLayer(torch.nn.Module):
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
@@ -112,6 +138,8 @@ class CompiledDecoderLayer(torch.nn.Module):
attention_mask,
),
)
del self.model
return tuple(
[
torch.tensor(new_hidden_states),
@@ -125,6 +153,314 @@ class CompiledDecoderLayer(torch.nn.Module):
)
class EightDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
self.falcon_variant = falcon_variant
def forward(self, hidden_states, attention_mask):
new_pkvs = []
for layer in self.model:
outputs = layer(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
hidden_states = outputs[0]
new_pkvs.append(
(
outputs[-1][0],
outputs[-1][1],
)
)
if self.falcon_variant == "7b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
elif self.falcon_variant == "180b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
)
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
)
return result
class CompiledEightDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
output = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
del self.model
if self.falcon_variant == "7b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
)
elif self.falcon_variant == "180b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
)
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
)
return result
class ShardedFalconModel:
def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
super().__init__()
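
For context on how the sharded layers above end up spread across GPUs: each layer (or group of layers) is compiled to its own .vmfb and, on ROCm, assigned to a device round-robin by index (see the compile() changes further down). Assuming the stock Falcon configs (32 decoder layers for 7B, 80 for 180B) and the group sizes used here (8 and 20), the compressed path yields four shard modules per model, which matches the 4-way sharding mentioned in the commit message. A minimal sketch of the assignment logic (the helper name and layer counts are illustrative, not taken from the diff):

    # Sketch: round-robin assignment of per-layer-group vmfbs to ROCm devices.
    # The vmfb naming mirrors CompiledDecoderLayer above; plan_shards itself is
    # a hypothetical helper, not part of the repository.
    def plan_shards(num_layers, group_size, num_devices, variant, precision, device):
        plan = []
        for i in range(num_layers // group_size):
            layer_id = f"{i * group_size}_{(i + 1) * group_size}"
            device_idx = i % num_devices if device == "rocm" else None
            vmfb = f"falcon_{variant}_layer_{layer_id}_{precision}_{device}.vmfb"
            plan.append((layer_id, device_idx, vmfb))
        return plan

    # Falcon-7B on 4 ROCm devices: 32 layers in groups of 8 -> 4 shards, one per GPU.
    print(plan_shards(32, 8, 4, "7b", "int4", "rocm"))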

View File

@@ -7,7 +7,9 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
LMHeadEmbeddingLayer,
CompiledLMHeadEmbeddingLayer,
DecoderLayer,
EightDecoderLayer,
CompiledDecoderLayer,
CompiledEightDecoderLayer,
ShardedFalconModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
@@ -27,12 +29,13 @@ from transformers.generation import (
StoppingCriteriaList,
)
import copy
import time
import re
import torch
import torch_mlir
import os
import argparse
import gc
parser = argparse.ArgumentParser(
prog="falcon runner",
@@ -42,6 +45,12 @@ parser = argparse.ArgumentParser(
parser.add_argument(
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
)
parser.add_argument(
"--compressed",
default=False,
action=argparse.BooleanOptionalAction,
help="Do the compression of sharded layers",
)
parser.add_argument(
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
)
@@ -127,7 +136,7 @@ class ShardedFalcon(SharkLLMBase):
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.src_model = self.get_src_model()
self.shark_model = self.compile()
self.shark_model = self.compile(compressed=args.compressed)
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
@@ -150,7 +159,7 @@ class ShardedFalcon(SharkLLMBase):
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
@@ -158,7 +167,9 @@ class ShardedFalcon(SharkLLMBase):
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile_layer(self, layer, falconCompileInput, layer_id):
def compile_layer(
self, layer, falconCompileInput, layer_id, device_idx=None
):
self.falcon_mlir_path = Path(
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
)
@@ -177,10 +188,13 @@ class ShardedFalcon(SharkLLMBase):
single_file=True,
)
vmfb = get_vmfb_from_path(
self.falcon_vmfb_path, self.device, "linalg"
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=device_idx,
)
if vmfb is not None:
return vmfb
return vmfb, device_idx
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
if self.falcon_mlir_path.exists():
@@ -215,7 +229,7 @@ class ShardedFalcon(SharkLLMBase):
f16_input_mask = [False]
elif layer_id in ["ln_f", "lm_head"]:
f16_input_mask = [True]
elif type(layer_id) == int:
elif "_" in layer_id or type(layer_id) == int:
f16_input_mask = [True, False]
else:
raise ValueError("Unsupported layer: ", layer_id)
@@ -257,6 +271,7 @@ class ShardedFalcon(SharkLLMBase):
mlir_module=self.falcon_mlir_path,
device=self.device,
mlir_dialect="linalg",
device_idx=device_idx,
)
path = shark_module.save_module(
self.falcon_vmfb_path.parent.absolute(),
@@ -276,27 +291,43 @@ class ShardedFalcon(SharkLLMBase):
print("Saved falcon vmfb at ", str(path))
shark_module.load_module(path)
return shark_module
return shark_module, device_idx
def compile(self):
def compile(self, compressed=False):
sample_input_ids = torch.zeros([100], dtype=torch.int64)
sample_attention_mask = torch.zeros(
[1, 1, 100, 100], dtype=torch.float32
)
num_group_layers = 1
if "7b" in self.model_name:
num_in_features = 4544
if compressed:
num_group_layers = 8
else:
num_in_features = 14848
sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
if compressed:
num_group_layers = 20
sample_hidden_states = torch.zeros(
[1, 100, num_in_features], dtype=torch.float32
)
# Determine number of available devices
num_devices = 1
if self.device == "rocm":
import iree.runtime as ireert
haldriver = ireert.get_driver(self.device)
num_devices = len(haldriver.query_available_devices())
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
print("Compiling Layer lm_head")
shark_lm_head = self.compile_layer(
lm_head, [sample_hidden_states], "lm_head"
shark_lm_head, _ = self.compile_layer(
lm_head,
[sample_hidden_states],
"lm_head",
device_idx=0 % num_devices if self.device == "rocm" else None,
)
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
@@ -304,8 +335,11 @@ class ShardedFalcon(SharkLLMBase):
self.src_model.transformer.word_embeddings
)
print("Compiling Layer word_embeddings")
shark_word_embedding = self.compile_layer(
word_embedding, [sample_input_ids], "word_embeddings"
shark_word_embedding, _ = self.compile_layer(
word_embedding,
[sample_input_ids],
"word_embeddings",
device_idx=1 % num_devices if self.device == "rocm" else None,
)
shark_word_embedding = CompiledWordEmbeddingsLayer(
shark_word_embedding
@@ -313,20 +347,56 @@ class ShardedFalcon(SharkLLMBase):
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
print("Compiling Layer ln_f")
shark_ln_f = self.compile_layer(ln_f, [sample_hidden_states], "ln_f")
shark_ln_f, _ = self.compile_layer(
ln_f,
[sample_hidden_states],
"ln_f",
device_idx=2 % num_devices if self.device == "rocm" else None,
)
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
shark_layers = []
for i in range(len(self.src_model.transformer.h)):
print("Compiling Layer {}".format(i))
layer_i = self.src_model.transformer.h[i]
pytorch_layer_i = DecoderLayer(layer_i)
shark_module = self.compile_layer(
for i in range(
int(len(self.src_model.transformer.h) / num_group_layers)
):
device_idx = i % num_devices if self.device == "rocm" else None
layer_id = i
pytorch_class = DecoderLayer
compiled_class = CompiledDecoderLayer
if compressed:
layer_id = (
str(i * num_group_layers)
+ "_"
+ str((i + 1) * num_group_layers)
)
pytorch_class = EightDecoderLayer
compiled_class = CompiledEightDecoderLayer
print("Compiling Layer {}".format(layer_id))
if compressed:
layer_i = self.src_model.transformer.h[
i * num_group_layers : (i + 1) * num_group_layers
]
else:
layer_i = self.src_model.transformer.h[i]
pytorch_layer_i = pytorch_class(
layer_i, args.falcon_variant_to_use
)
shark_module, device_idx = self.compile_layer(
pytorch_layer_i,
[sample_hidden_states, sample_attention_mask],
i,
layer_id,
device_idx=device_idx,
)
del shark_module
shark_layer_i = compiled_class(
layer_id,
device_idx,
args.falcon_variant_to_use,
self.device,
self.precision,
)
shark_layer_i = CompiledDecoderLayer(shark_module)
shark_layers.append(shark_layer_i)
sharded_model = ShardedFalconModel(
@@ -355,9 +425,6 @@ class ShardedFalcon(SharkLLMBase):
if input_ids.shape[1] == 0:
input_ids = None
attention_mask = None
in_b = 1
else:
in_b = input_ids.shape[0]
generate_kwargs = {
"max_length": self.max_num_tokens,
@@ -387,7 +454,6 @@ class ShardedFalcon(SharkLLMBase):
) = self.src_model._prepare_model_inputs(
None, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs[
@@ -438,8 +504,6 @@ class ShardedFalcon(SharkLLMBase):
self.eos_token_id = eos_token_id
output_scores = generation_config.output_scores # False
output_attentions = generation_config.output_attentions # False
output_hidden_states = generation_config.output_hidden_states # False
return_dict_in_generate = (
generation_config.return_dict_in_generate # False
)
@@ -448,15 +512,6 @@ class ShardedFalcon(SharkLLMBase):
self.scores = (
() if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
cross_attentions = (
() if (return_dict_in_generate and output_attentions) else None
)
decoder_hidden_states = (
() if (return_dict_in_generate and output_hidden_states) else None
)
# keep track of which sequences are already finished
self.unfinished_sequences = torch.ones(
@@ -465,7 +520,11 @@ class ShardedFalcon(SharkLLMBase):
all_text = prompt
start = time.time()
count = 0
for i in range(self.max_num_tokens - 1):
count = count + 1
next_token = self.generate_new_token()
new_word = self.tokenizer.decode(
next_token.cpu().numpy(),
@@ -477,6 +536,7 @@ class ShardedFalcon(SharkLLMBase):
all_text = all_text + new_word
print(f"{new_word}", end="", flush=True)
print(f"{all_text}", end="", flush=True)
# if eos_token was found in one sentence, set sentence to finished
if self.eos_token_id_tensor is not None:
@@ -492,6 +552,13 @@ class ShardedFalcon(SharkLLMBase):
):
break
end = time.time()
print(
"\n\nTime taken is {:.2f} seconds/token\n".format(
(end - start) / count
)
)
torch.cuda.empty_cache()
gc.collect()
@@ -1023,8 +1090,6 @@ if __name__ == "__main__":
precision=args.precision,
)
import gc
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
continue_execution = True
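
The --compressed flag added above is what switches the pipeline from the per-layer DecoderLayer/CompiledDecoderLayer path to the grouped EightDecoderLayer/CompiledEightDecoderLayer path. A standalone sketch of how the new arguments parse (it mirrors the argparse additions above rather than importing the pipeline itself):

    import argparse

    # Mirror of the arguments shown in the "falcon runner" parser above.
    parser = argparse.ArgumentParser(prog="falcon runner")
    parser.add_argument("--falcon_variant_to_use", default="7b", help="7b, 40b, 180b")
    parser.add_argument(
        "--compressed",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Do the compression of sharded layers",
    )
    parser.add_argument(
        "--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
    )

    args = parser.parse_args(["--falcon_variant_to_use", "180b", "--compressed"])
    # With --compressed, ShardedFalcon.compile(compressed=True) groups decoder
    # layers before compiling: 8 per group for 7b, 20 per group for 180b.
    print(args.falcon_variant_to_use, args.precision, args.compressed)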

View File

@@ -53,6 +53,7 @@ datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("langchain")
datas += collect_data_files("cv2")
datas += collect_data_files("einops")
datas += [
("src/utils/resources/prompts.json", "resources"),
("src/utils/resources/model_db.json", "resources"),

View File

@@ -11,12 +11,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
@@ -28,7 +28,7 @@
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
@@ -37,7 +37,7 @@
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
@@ -45,12 +45,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}

View File

@@ -203,8 +203,8 @@ def dump_after_mlir(input_mlir, use_winograd):
if use_winograd:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))"
@@ -212,8 +212,8 @@ def dump_after_mlir(input_mlir, use_winograd):
else:
preprocess_flag = (
"--iree-preprocessing-pass-pipeline=builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
)

View File

@@ -422,7 +422,7 @@ p.add_argument(
p.add_argument(
"--use_stencil",
choices=["canny", "openpose", "scribble"],
choices=["canny", "openpose", "scribble", "zoedepth"],
help="Enable the stencil feature.",
)
@@ -725,6 +725,17 @@ p.add_argument(
help="Specifies whether the docuchat's web version is running or not.",
)
##############################################################################
# rocm Flags
##############################################################################
p.add_argument(
"--iree_rocm_target_chip",
type=str,
default="gfx1100",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Default gfx1100",
)
args, unknown = p.parse_known_args()
if args.import_debug:
os.environ["IREE_SAVE_TEMPS"] = os.path.join(

View File

@@ -1,2 +1,3 @@
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
from apps.stable_diffusion.src.utils.stencils.zoe import ZoeDetector

View File

@@ -4,6 +4,7 @@ import torch
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
ZoeDetector,
)
stencil = {}
@@ -117,6 +118,9 @@ def controlnet_hint_conversion(
case "scribble":
print("Working with scribble")
controlnet_hint = hint_scribble(image)
case "zoedepth":
print("Working with ZoeDepth")
controlnet_hint = hint_zoedepth(image)
case _:
return None
controlnet_hint = controlnet_hint_shaping(
@@ -127,7 +131,7 @@ def controlnet_hint_conversion(
stencil_to_model_id_map = {
"canny": "lllyasviel/control_v11p_sd15_canny",
"depth": "lllyasviel/control_v11p_sd15_depth",
"zoedepth": "lllyasviel/control_v11f1p_sd15_depth",
"hed": "lllyasviel/sd-controlnet-hed",
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
"normal": "lllyasviel/control_v11p_sd15_normalbae",
@@ -184,3 +188,16 @@ def hint_scribble(image: Image.Image):
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
return detected_map
# Stencil 4. Depth (Only Zoe Preprocessing)
def hint_zoedepth(image: Image.Image):
with torch.no_grad():
input_image = np.array(image)
if not "depth" in stencil:
stencil["depth"] = ZoeDetector()
detected_map = stencil["depth"](input_image)
detected_map = HWC3(detected_map)
return detected_map

View File

@@ -0,0 +1,58 @@
import numpy as np
import torch
from pathlib import Path
import requests
from einops import rearrange
remote_model_path = (
"https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt"
)
class ZoeDetector:
def __init__(self):
cwd = Path.cwd()
ckpt_path = Path(cwd, "stencil_annotator")
ckpt_path.mkdir(parents=True, exist_ok=True)
modelpath = ckpt_path / "ZoeD_M12_N.pt"
with requests.get(remote_model_path, stream=True) as r:
r.raise_for_status()
with open(modelpath, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
model = torch.hub.load(
"monorimet/ZoeDepth:torch_update",
"ZoeD_N",
pretrained=False,
force_reload=False,
)
model.load_state_dict(
torch.load(modelpath, map_location=model.device)["model"]
)
model.eval()
self.model = model
def __call__(self, input_image):
assert input_image.ndim == 3
image_depth = input_image
with torch.no_grad():
image_depth = torch.from_numpy(image_depth).float()
image_depth = image_depth / 255.0
image_depth = rearrange(image_depth, "h w c -> 1 c h w")
depth = self.model.infer(image_depth)
depth = depth[0, 0].cpu().numpy()
vmin = np.percentile(depth, 2)
vmax = np.percentile(depth, 85)
depth -= vmin
depth /= vmax - vmin
depth = 1.0 - depth
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
return depth_image
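
A minimal usage sketch for the detector above (the file path is hypothetical; the stencil pipeline in the next file wires this up via hint_zoedepth):

    import numpy as np
    from PIL import Image

    # Standalone use of ZoeDetector: produce an 8-bit inverted-depth map suitable
    # as a ControlNet depth hint. Constructing the detector downloads the
    # ZoeD_M12_N checkpoint (see __init__ above).
    detector = ZoeDetector()
    image = np.array(Image.open("input.png").convert("RGB"))  # H x W x 3 uint8
    depth_hint = detector(image)                               # H x W uint8
    Image.fromarray(depth_hint).save("depth_hint.png")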

View File

@@ -895,6 +895,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
# Using a conditional expression caused problems, so setting a new
# variable for now.
if args.use_hiresfix:
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
else:
png_size_text = f"{args.width}x{args.height}"
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}"
@@ -903,7 +910,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
f"Sampler: {args.scheduler}, "
f"CFG scale: {args.guidance_scale}, "
f"Seed: {img_seed},"
f"Size: {args.width}x{args.height}, "
f"Size: {png_size_text}, "
f"Model: {img_model}, "
f"VAE: {img_vae}, "
f"LoRA: {img_lora}",
@@ -930,8 +937,10 @@ def save_output_img(output_img, img_seed, extra_info=None):
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"HEIGHT": args.height
if not args.use_hiresfix
else args.hiresfix_height,
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
"VAE": img_vae,
@@ -969,6 +978,10 @@ def get_generation_text_info(seeds, device):
)
text_output += (
f"\nsize={args.height}x{args.width}, "
if not args.use_hiresfix
else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
)
text_output += (
f"batch_count={args.batch_count}, "
f"batch_size={args.batch_size}, "
f"max_length={args.max_length}"

View File

@@ -435,8 +435,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
elem_id="stencil_model",
label="Stencil model",
value="None",
choices=["None", "canny", "openpose", "scribble"],
allow_custom_value=True,
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
],
)
def show_canvas(choice):
@@ -638,16 +643,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -667,6 +662,18 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
show_label=False,
)
img2img_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
img2img_sendto_outpaint = gr.Button(

View File

@@ -514,16 +514,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -543,7 +533,18 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
show_label=False,
)
inpaint_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
inpaint_sendto_outpaint = gr.Button(

View File

@@ -540,16 +540,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -569,6 +559,18 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
show_label=False,
)
outpaint_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")

View File

@@ -32,36 +32,39 @@ model_map = {
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_13b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"System: You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
"content. Please ensure that your responses are socially unbiased and positive "
"in nature. If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. If you don't know the "
"answer to a question, please don't share false information."
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's "
"questions.\n"
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
}
@@ -77,7 +80,10 @@ def create_prompt(model_name, history, prompt_prefix):
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
@@ -210,8 +216,14 @@ def chat(
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")
print(f"Will use target triple : {vulkan_target_triple}")
elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
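
To make the llama2 template fix above concrete, here is roughly what create_prompt now builds for a single-turn history, assuming the usual LLaMA-2 markers for B_INST/E_INST/B_SYS/E_SYS (those constants are defined elsewhere in the file and are not part of this diff):

    # Assumed marker values; verify against the definitions in the file.
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    system_message = "You are a helpful, respectful and honest assistant. ..."
    history = [["Hello, who are you?", ""]]  # [user message, assistant reply so far]
    conversation = ""  # no earlier turns

    # With a prompt prefix: the system block is folded into the first [INST] span.
    with_prefix = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
    # Without a prefix: the fix drops the system block entirely.
    without_prefix = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
    print(with_prefix)
    print(without_prefix)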

View File

@@ -140,6 +140,11 @@ def txt2img_inf(
args.max_length = max_length
args.height = height
args.width = width
args.use_hiresfix = use_hiresfix
args.hiresfix_height = hiresfix_height
args.hiresfix_width = hiresfix_width
args.hiresfix_strength = hiresfix_strength
args.resample_type = resample_type
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform

View File

@@ -533,16 +533,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
choices=available_devices,
allow_custom_value=True,
)
with gr.Row():
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
stable_diffusion = gr.Button("Generate Image(s)")
with gr.Column(scale=1, min_width=600):
with gr.Group():
@@ -562,7 +552,18 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
show_label=False,
)
upscaler_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
with gr.Row():
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")

View File

@@ -129,12 +129,12 @@ pytest_benchmark_param = pytest.mark.parametrize(
pytest.param(True, "cpu", marks=pytest.mark.skip),
pytest.param(
False,
"gpu",
"cuda",
marks=pytest.mark.skipif(
check_device_drivers("gpu"), reason="nvidia-smi not found"
check_device_drivers("cuda"), reason="nvidia-smi not found"
),
),
pytest.param(True, "gpu", marks=pytest.mark.skip),
pytest.param(True, "cuda", marks=pytest.mark.skip),
pytest.param(
False,
"vulkan",

View File

@@ -41,6 +41,7 @@ tiktoken # for codegen
joblib # for langchain
timm # for MiniGPT4
langchain
einops # for zoedepth
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile

View File

@@ -95,38 +95,31 @@ _IREE_TARGET_MAP = {
# Finds whether the required drivers are installed for the given device.
@functools.cache
def check_device_drivers(device):
"""Checks necessary drivers present for gpu and vulkan devices"""
"""
Checks whether the necessary drivers are present for the given device.
Returns False when the drivers are present (i.e. the device is usable).
"""
if "://" in device:
device = device.split("://")[0]
if device == "cuda":
try:
subprocess.check_output("nvidia-smi")
except Exception:
return True
elif device in ["vulkan"]:
try:
subprocess.check_output("vulkaninfo")
except Exception:
return True
elif device == "metal":
return False
elif device in ["intel-gpu"]:
try:
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
return False
except Exception:
return True
elif device == "cpu":
return False
elif device == "rocm":
try:
if sys.platform == "win32":
subprocess.check_output("hipinfo")
else:
subprocess.check_output("rocminfo")
except Exception:
return True
from iree.runtime import get_driver
device_mapped = iree_device_map(device)
try:
_ = get_driver(device_mapped)
except ValueError as ve:
print(
f"[ERR] device `{device}` not registered with IREE. "
"Ensure IREE is configured for use with this device.\n"
f"Full Error: \n {repr(ve)}"
)
return True
except RuntimeError as re:
print(
f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}"
)
return True
# Unknown device. We assume drivers are installed.
return False
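
Note the return convention, which the old and new implementations share: check_device_drivers returns False when the device is usable and True when a driver or runtime appears to be missing. A small usage sketch, assuming both helpers are imported from shark.iree_utils._common (this mirrors how the test parametrization elsewhere in this diff consumes them):

    device = "rocm"
    if check_device_drivers(device):       # True means drivers are NOT present
        print(device_driver_info(device))  # human-readable installation hint
    else:
        print(f"{device} is ready to use")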
@@ -134,11 +127,32 @@ def check_device_drivers(device):
# Installation info for the missing device drivers.
def device_driver_info(device):
if device == "cuda":
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
elif device in ["metal", "vulkan"]:
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
elif device == "rocm":
return "rocm info not found. Please install rocm"
device_driver_err_map = {
"cuda": {
"debug": "Try `nvidia-smi` on system to check.",
"solution": " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.",
},
"vulkan": {
"debug": "Try `vulkaninfo` on system to check.",
"solution": " from https://vulkan.lunarg.com/sdk/home for your distribution.",
},
"metal": {
"debug": "Check if Bare metal is supported and enabled on your system.",
"solution": ".",
},
"rocm": {
"debug": f"Try `{'hip' if sys.platform == 'win32' else 'rocm'}info` on system to check.",
"solution": " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.",
},
}
if device in device_driver_err_map:
err_msg = (
f"Required drivers for {device} not found. {device_driver_err_map[device]['debug']} "
f"Please install the required drivers{device_driver_err_map[device]['solution']} "
f"For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]"
f" and/or file a bug at https://github.com/nod-ai/SHARK/issues"
)
return err_msg
else:
return f"{device} is not supported."

View File

@@ -75,7 +75,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args()
return get_iree_rocm_args(extra_args=extra_args)
return []
@@ -392,6 +392,9 @@ def load_vmfb_using_mmap(
)
dl.log(f"ireert.create_device()")
config = ireert.Config(device=haldevice)
config.id = haldriver.query_available_devices()[device_idx][
"device_id"
]
dl.log(f"ireert.Config()")
else:
config = get_iree_runtime_config(device)
@@ -574,10 +577,17 @@ def get_results(
frontend="torch",
send_to_host=True,
debug_timeout: float = 5.0,
device: str = None,
):
"""Runs a .vmfb file given inputs and config and returns output."""
with DetailLogger(debug_timeout) as dl:
device_inputs = []
if device == "rocm":
haldriver = ireert.get_driver("rocm")
haldevice = haldriver.create_device(
config.id,
allocators=shark_args.device_allocator,
)
for input_array in input:
dl.log(f"Load to device: {input_array.shape}")
device_inputs.append(

View File

@@ -40,55 +40,26 @@ def get_iree_gpu_args():
# Get the default gpu args given the architecture.
@functools.cache
def get_iree_rocm_args():
def get_iree_rocm_args(extra_args=[]):
ireert.flags.FUNCTION_INPUT_VALIDATION = False
# get arch from hipinfo.
import os
import re
import subprocess
rocm_flags = ["--iree-rocm-link-bc=true"]
if sys.platform == "win32":
if "HIP_PATH" in os.environ:
rocm_path = os.environ["HIP_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
rocm_path = "C:\\AMD\\ROCM\\5.5"
else:
if "ROCM_PATH" in os.environ:
rocm_path = os.environ["ROCM_PATH"]
print(f"Found a ROCm installation at {rocm_path}.")
else:
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
rocm_path = "/opt/rocm/"
try:
if sys.platform == "win32":
rocm_arch = re.search(
r"gfx\d{3,}",
subprocess.check_output("hipinfo", shell=True, text=True),
).group(0)
else:
rocm_arch = re.match(
r".*(gfx\w+)",
subprocess.check_output(
"rocminfo | grep -i 'gfx'", shell=True, text=True
),
).group(1)
print(f"Found rocm arch {rocm_arch}...")
except:
# Add the target arch flag for rocm device
flag_present = False
for flag in extra_args:
if "iree-rocm-target-chip" in flag:
flag_present = True
print(
f"found rocm target device arch from flag : {flag.split('=')[1]}"
)
if not flag_present:
print(
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
)
rocm_arch = "gfx1100"
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
return [
f"--iree-rocm-target-chip={rocm_arch}",
"--iree-rocm-link-bc=true",
f"--iree-rocm-bc-dir={bc_path}",
]
return rocm_flags
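
With the rewrite above, compilation no longer shells out to hipinfo/rocminfo: the target chip either arrives via extra_args (typically forwarded from --iree_rocm_target_chip) or falls back to gfx1100. Roughly, and assuming the fallback append only happens when no chip flag was passed (indentation is not visible in this view):

    # Caller already supplies the chip flag, e.g. forwarded from the CLI:
    get_iree_rocm_args(extra_args=["--iree-rocm-target-chip=gfx90a"])
    # -> ["--iree-rocm-link-bc=true"]

    # No chip flag anywhere: fall back to the default architecture.
    get_iree_rocm_args()
    # -> ["--iree-rocm-link-bc=true", "--iree-rocm-target-chip=gfx1100"]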
# Some constants taken from cuda.h

View File

@@ -68,6 +68,8 @@ def get_vulkan_target_triple(device_name):
Returns:
str or None: target triple or None if no match found for given name
"""
# TODO: Replace this with a dict or something smarter.
system_os = get_os_name()
# Apple Targets
if all(x in device_name for x in ("Apple", "M1")):
@@ -117,6 +119,8 @@ def get_vulkan_target_triple(device_name):
# Amd Targets
# Linux: Radeon RX 7900 XTX
# Windows: AMD Radeon RX 7900 XTX
elif all(x in device_name for x in ("RX", "7800")):
triple = f"rdna3-7800-{system_os}"
elif all(x in device_name for x in ("RX", "7900")):
triple = f"rdna3-7900-{system_os}"
elif all(x in device_name for x in ("Radeon", "780M")):

View File

@@ -150,11 +150,15 @@ class SharkInference:
# inputs are considered to be tuple of np.array.
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
return self.shark_runner.run(function_name, inputs, send_to_host)
return self.shark_runner.run(
function_name, inputs, send_to_host, device=self.device
)
# forward function.
def forward(self, inputs: tuple, send_to_host=True):
return self.shark_runner.run("forward", inputs, send_to_host)
return self.shark_runner.run(
"forward", inputs, send_to_host, device=self.device
)
# Get all function names defined within the compiled module.
def get_functions_in_module(self):

View File

@@ -109,7 +109,9 @@ class SharkRunner:
self.temp_file_to_unlink = params["temp_file_to_unlink"]
del params
def run(self, function_name, inputs: tuple, send_to_host=False):
def run(
self, function_name, inputs: tuple, send_to_host=False, device=None
):
return get_results(
self.iree_compilation_module,
function_name,
@@ -117,6 +119,7 @@ class SharkRunner:
self.iree_config,
self.mlir_dialect,
send_to_host,
device=device,
)
# Get all function names defined within the compiled module.

View File

@@ -1,4 +1,3 @@
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
from shark.parser import shark_args

View File

@@ -1,30 +1,25 @@
import argparse
import os
import torch
import numpy as np
from shark_opt_wrapper import OPTForCausalLMModel
from shark.iree_utils._common import (
check_device_drivers,
device_driver_info,
)
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
OPT_MODEL = "opt-1.3b"
OPT_FS_NAME = "opt-1_3b"
MAX_SEQUENCE_LENGTH = 128
MAX_NEW_TOKENS = 60
from typing import Iterable
def create_module(model_name, tokenizer, device):
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
def create_module(model_name, tokenizer, device, args):
opt_base_model = OPTForCausalLM.from_pretrained(
model_name, ignore_mismatched_sizes=True
)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
"What is the meaning of life?",
padding="max_length",
truncation=True,
max_length=MAX_SEQUENCE_LENGTH,
max_length=args.max_seq_len,
return_tensors="pt",
)
inputs = (
@@ -33,8 +28,11 @@ def create_module(model_name, tokenizer, device):
)
# np.save("model_inputs_0.npy", inputs[0])
# np.save("model_inputs_1.npy", inputs[1])
opt_fs_name = "-".join(
"_".join(args.model_name.split("/")[1].split("-")).split(".")
)
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
mlir_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch.mlir"
if os.path.isfile(mlir_path):
print(f"Found .mlir from {mlir_path}")
else:
@@ -42,7 +40,7 @@ def create_module(model_name, tokenizer, device):
model=opt_model,
inputs=inputs,
is_f16=False,
model_name=OPT_FS_NAME,
model_name=opt_fs_name,
return_str=True,
)
with open(mlir_path, "w") as f:
@@ -57,7 +55,7 @@ def create_module(model_name, tokenizer, device):
is_benchmark=False,
)
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
vmfb_name = f"{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu"
shark_module.save_module(module_name=vmfb_name, debug=False)
vmfb_path = vmfb_name + ".vmfb"
return vmfb_path
@@ -71,11 +69,11 @@ def shouldStop(tokens):
return False
def generate_new_token(shark_model, tokenizer, new_text):
def generate_new_token(shark_module, tokenizer, new_text, max_seq_len: int):
model_inputs = tokenizer(
new_text,
padding="max_length",
max_length=MAX_SEQUENCE_LENGTH,
max_length=max_seq_len,
truncation=True,
return_tensors="pt",
)
@@ -84,7 +82,7 @@ def generate_new_token(shark_model, tokenizer, new_text):
model_inputs["attention_mask"],
)
sum_attentionmask = torch.sum(model_inputs.attention_mask)
output = shark_model("forward", inputs)
output = shark_module("forward", inputs)
output = torch.FloatTensor(output[0])
next_toks = torch.topk(output, 1)
stop_generation = False
@@ -104,39 +102,96 @@ def generate_new_token(shark_model, tokenizer, new_text):
return ret_dict
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--max-seq-len", type=int, default=32)
parser.add_argument(
"--model-name",
help="Model name",
type=str,
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
"mit-han-lab/opt-125m-smoothquant",
"mit-han-lab/opt-1.3b-smoothquant",
"mit-han-lab/opt-2.7b-smoothquant",
"mit-han-lab/opt-6.7b-smoothquant",
"mit-han-lab/opt-13b-smoothquant",
],
default="facebook/opt-1.3b",
)
parser.add_argument(
"--recompile",
help="If set, recompiles MLIR -> .vmfb",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--plugin-path",
help="path to executable plugin",
type=str,
default=None,
)
args = parser.parse_args()
print("args={}".format(args))
return args
def generate_tokens(
opt_shark_module: "SharkInference",
tokenizer,
input_text: str,
max_output_len: int,
print_intermediate_results: bool = True,
) -> Iterable[str]:
words_list = []
new_text = input_text
try:
for _ in range(max_output_len):
generated_token_op = generate_new_token(
opt_shark_module, tokenizer, new_text, max_output_len
)
detok = generated_token_op["detok"]
if generated_token_op["stop_generation"]:
break
if print_intermediate_results:
print(detok, end="", flush=True)
words_list.append(detok)
if detok == "":
break
new_text += detok
except KeyboardInterrupt as e:
print("Exiting token generation.")
return words_list
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(
"facebook/" + OPT_MODEL, use_fast=False
args = parse_args()
if "smoothquant" in args.model_name:
token_model_name = f"facebook/opt-{args.model_name.split('-')[3]}"
else:
token_model_name = args.model_name
tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
opt_fs_name = "-".join(
"_".join(args.model_name.split("/")[1].split("-")).split(".")
)
vmfb_path = (
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
if args.plugin_path is not None:
rt_flags = [f"--executable_plugin={args.plugin_path}"]
else:
rt_flags = []
opt_shark_module = SharkInference(
mlir_module=None, device="cpu-task", rt_flags=rt_flags
)
opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
if os.path.isfile(vmfb_path):
opt_shark_module.load_module(vmfb_path)
else:
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
opt_shark_module.load_module(vmfb_path)
while True:
try:
new_text = input("Give me a sentence to complete:")
new_text_init = new_text
words_list = []
for i in range(MAX_NEW_TOKENS):
generated_token_op = generate_new_token(
opt_shark_module, tokenizer, new_text
)
detok = generated_token_op["detok"]
stop_generation = generated_token_op["stop_generation"]
if stop_generation:
break
print(detok, end="", flush=True)
words_list.append(detok)
if detok == "":
break
new_text = new_text + detok
except KeyboardInterrupt:
print("Exiting program.")
break
input_text = input("Give me a sentence to complete:")
generate_tokens(
opt_shark_module, tokenizer, input_text, args.max_seq_len
)
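
The name mangling above is easy to misread, so here is a worked example of the two derivations for a smoothquant checkpoint:

    model_name = "mit-han-lab/opt-1.3b-smoothquant"

    # Tokenizer: smoothquant checkpoints reuse the matching facebook/opt tokenizer.
    token_model_name = f"facebook/opt-{model_name.split('-')[3]}"
    # split('-') -> ['mit', 'han', 'lab/opt', '1.3b', 'smoothquant'] -> "facebook/opt-1.3b"

    # Filesystem-safe name used for the .mlir / .vmfb artifacts.
    opt_fs_name = "-".join("_".join(model_name.split("/")[1].split("-")).split("."))
    # "opt-1.3b-smoothquant" -> "opt_1.3b_smoothquant" -> "opt_1-3b_smoothquant"

    # With --max-seq-len 32 the vmfb ends up at:
    print(f"./{opt_fs_name}_causallm_32_torch_cpu.vmfb")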

View File

@@ -0,0 +1,74 @@
import argparse
import os
import opt_causallm
import opt_util
from shark.shark_inference import SharkInference
from transformers import AutoTokenizer, OPTForCausalLM
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--max-seq-len", type=int, default=32)
parser.add_argument(
"--model-name",
help="Model name",
type=str,
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
],
default="facebook/opt-1.3b",
)
parser.add_argument(
"--recompile",
help="If set, recompiles MLIR -> .vmfb",
action=argparse.BooleanOptionalAction,
default=False,
)
parser.add_argument(
"--plugin-path",
help="path to executable plugin",
type=str,
default=None,
)
args = parser.parse_args()
print("args={}".format(args))
return args
if __name__ == "__main__":
args = parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
opt_fs_name = "-".join(
"_".join(args.model_name.split("/")[1].split("-")).split(".")
)
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
if args.plugin_path is not None:
rt_flags = [f"--executable_plugin={args.plugin_path}"]
else:
rt_flags = []
opt_shark_module = SharkInference(
mlir_module=None, device="cpu-task", rt_flags=rt_flags
)
if os.path.isfile(vmfb_path):
opt_shark_module.load_module(vmfb_path)
else:
vmfb_path = opt_causallm.create_module(
args.model_name, tokenizer, "cpu-task", args
)
opt_shark_module.load_module(vmfb_path)
for prompt in opt_util.PROMPTS:
print("\n\nprompt: {}".format(prompt))
response = opt_causallm.generate_tokens(
opt_shark_module,
tokenizer,
prompt,
args.max_seq_len,
print_intermediate_results=False,
)
print("reponse: {}".format("".join(response)))

View File

@@ -22,6 +22,7 @@ import time
import numpy as np
from typing import Tuple
from opt_util import PROMPTS
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from transformers import AutoTokenizer, OPTForCausalLM
@@ -44,19 +45,6 @@ REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
PROMPTS = [
"What is the meaning of life?",
"Tell me something you don't know.",
"What does Xilinx do?",
"What is the mass of earth?",
"What is a poem?",
"What is recursion?",
"Tell me a one line joke.",
"Who is Gilgamesh?",
"Tell me something about cryptocurrency.",
"How did it all begin?",
]
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
@@ -72,7 +60,9 @@ def import_mlir_module(
device: str,
max_seq_len: int,
):
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
opt_base_model = OPTForCausalLM.from_pretrained(
model_name, ignore_mismatched_sizes=True
)
opt_base_model.eval()
opt_model = OPTForCausalLMModel(opt_base_model)
encoded_inputs = tokenizer(
@@ -142,13 +132,14 @@ def create_vmfb_module(
def load_shark_model(
model_name: str,
token_model_name: str,
max_seq_len: int,
recompile_shark: bool,
plugin_path: str = [],
) -> ModelWrapper:
opt_fs_name = get_opt_fs_name(model_name)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
if recompile_shark or not os.path.isfile(vmfb_name):
print(f"vmfb not found. compiling and saving to {vmfb_name}")
create_vmfb_module(
@@ -170,10 +161,12 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
return model_wrapper.model("forward", tokens)
def load_huggingface_model(model_name: str) -> ModelWrapper:
def load_huggingface_model(
model_name: str, token_model_name: str
) -> ModelWrapper:
return ModelWrapper(
model=OPTForCausalLM.from_pretrained(model_name),
tokenizer=AutoTokenizer.from_pretrained(model_name),
tokenizer=AutoTokenizer.from_pretrained(token_model_name),
)
@@ -189,11 +182,14 @@ def save_json(data, filename):
def collect_huggingface_logits(
model_name: str, max_seq_len: int, to_save_json: bool
model_name: str,
token_model_name: str,
max_seq_len: int,
to_save_json: bool,
) -> Tuple[float, float]:
# Load
t0 = time.time()
model_wrapper = load_huggingface_model(model_name)
model_wrapper = load_huggingface_model(model_name, token_model_name)
load_time = time.time() - t0
print("--- Took {} seconds to load Huggingface.".format(load_time))
load_memory_info = get_memory_info()
@@ -237,6 +233,7 @@ def collect_huggingface_logits(
def collect_shark_logits(
model_name: str,
token_model_name: str,
max_seq_len: int,
recompile_shark: bool,
to_save_json: bool,
@@ -245,7 +242,7 @@ def collect_shark_logits(
# Load
t0 = time.time()
model_wrapper = load_shark_model(
model_name, max_seq_len, recompile_shark, plugin_path
model_name, token_model_name, max_seq_len, recompile_shark, plugin_path
)
load_time = time.time() - t0
print("--- Took {} seconds to load Shark.".format(load_time))
@@ -327,6 +324,11 @@ def parse_args():
"facebook/opt-350m",
"facebook/opt-1.3b",
"facebook/opt-6.7b",
"mit-han-lab/opt-125m-smoothquant",
"mit-han-lab/opt-1.3b-smoothquant",
"mit-han-lab/opt-2.7b-smoothquant",
"mit-han-lab/opt-6.7b-smoothquant",
"mit-han-lab/opt-13b-smoothquant",
],
default="facebook/opt-1.3b",
)
@@ -344,11 +346,17 @@ def parse_args():
default=PLATFORM_SHARK,
)
parser.add_argument(
"--plugin_path",
"--plugin-path",
help="path to executable plugin",
type=str,
default=None,
)
parser.add_argument(
"--token-model-name",
help="HF ID to create tokenizer.",
type=str,
default=None,
)
args = parser.parse_args()
print("args={}".format(args))
return args
@@ -356,9 +364,17 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
if args.token_model_name == None:
if "smoothquant" in args.model_name:
args.token_model_name = (
f"facebook/opt-{args.model_name.split('-')[3]}"
)
else:
args.token_model_name = args.model_name
if args.platform == PLATFORM_SHARK:
shark_report = collect_shark_logits(
args.model_name,
args.token_model_name,
args.max_seq_len,
args.recompile_shark,
args.save_json,
@@ -367,6 +383,9 @@ if __name__ == "__main__":
print("# Summary: {}".format(json.dumps(shark_report)))
else:
huggingface_report = collect_huggingface_logits(
args.model_name, args.max_seq_len, args.save_json
args.model_name,
args.token_model_name,
args.max_seq_len,
args.save_json,
)
print("# Summary: {}".format(json.dumps(huggingface_report)))

View File

@@ -0,0 +1,12 @@
PROMPTS = [
"What is the meaning of life?",
"Tell me something you don't know.",
"What does Xilinx do?",
"What is the mass of earth?",
"What is a poem?",
"What is recursion?",
"Tell me a one line joke.",
"Who is Gilgamesh?",
"Tell me something about cryptocurrency.",
"How did it all begin?",
]

View File

@@ -1,4 +1,3 @@
from shark.iree_utils._common import check_device_drivers, device_driver_info
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_model
from tank.test_utils import get_valid_test_params, shark_test_name_func

View File

@@ -44,7 +44,7 @@ class TapasBaseModuleTest(unittest.TestCase):
self.module_tester.create_and_check_module(dynamic, device)
@pytest.mark.skipif(
check_device_drivers("cuda"), reason=device_driver_info("gpu")
check_device_drivers("cuda"), reason=device_driver_info("cuda")
)
def test_module_static_cuda(self):
dynamic = False