mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
17 Commits
20231021.9
...
20231104.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
322874f7f9 | ||
|
|
5001db3415 | ||
|
|
71846344a2 | ||
|
|
72e27c96fc | ||
|
|
7963abb8ec | ||
|
|
98244232dd | ||
|
|
679a452139 | ||
|
|
72c0a8abc8 | ||
|
|
ea920f2955 | ||
|
|
486202377a | ||
|
|
0c38c33d0a | ||
|
|
841773fa32 | ||
|
|
0361db46f9 | ||
|
|
a012433ffd | ||
|
|
5061193da3 | ||
|
|
bff48924be | ||
|
|
825b36cbdd |
@@ -70,7 +70,7 @@ class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
|
||||
|
||||
|
||||
class DecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
|
||||
@@ -85,9 +85,15 @@ class DecoderLayer(torch.nn.Module):
|
||||
|
||||
|
||||
class CompiledDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, shark_decoder_layer_module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision
|
||||
):
|
||||
super().__init__()
|
||||
self.model = shark_decoder_layer_module
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -99,6 +105,26 @@ class CompiledDecoderLayer(torch.nn.Module):
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import get_vmfb_from_path
|
||||
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
print("vmfb path for layer: ", self.falcon_vmfb_path)
|
||||
self.model = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=self.device_index,
|
||||
)
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
|
||||
@@ -112,6 +138,8 @@ class CompiledDecoderLayer(torch.nn.Module):
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
del self.model
|
||||
|
||||
return tuple(
|
||||
[
|
||||
torch.tensor(new_hidden_states),
|
||||
@@ -125,6 +153,314 @@ class CompiledDecoderLayer(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class EightDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
self.falcon_variant = falcon_variant
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
new_pkvs = []
|
||||
for layer in self.model:
|
||||
outputs = layer(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
if self.falcon_variant == "7b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
elif self.falcon_variant == "180b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported Falcon variant: ", self.falcon_variant
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledEightDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
alibi: torch.Tensor = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import get_vmfb_from_path
|
||||
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
print("vmfb path for layer: ", self.falcon_vmfb_path)
|
||||
self.model = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=self.device_index,
|
||||
)
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
else:
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
del self.model
|
||||
|
||||
if self.falcon_variant == "7b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
)
|
||||
elif self.falcon_variant == "180b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported Falcon variant: ", self.falcon_variant
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class ShardedFalconModel:
|
||||
def __init__(self, model, layers, word_embeddings, ln_f, lm_head):
|
||||
super().__init__()
|
||||
|
||||
@@ -7,7 +7,9 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
|
||||
LMHeadEmbeddingLayer,
|
||||
CompiledLMHeadEmbeddingLayer,
|
||||
DecoderLayer,
|
||||
EightDecoderLayer,
|
||||
CompiledDecoderLayer,
|
||||
CompiledEightDecoderLayer,
|
||||
ShardedFalconModel,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
@@ -27,12 +29,13 @@ from transformers.generation import (
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
import copy
|
||||
|
||||
import time
|
||||
import re
|
||||
import torch
|
||||
import torch_mlir
|
||||
import os
|
||||
import argparse
|
||||
import gc
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="falcon runner",
|
||||
@@ -42,6 +45,12 @@ parser = argparse.ArgumentParser(
|
||||
parser.add_argument(
|
||||
"--falcon_variant_to_use", default="7b", help="7b, 40b, 180b"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compressed",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Do the compression of sharded layers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
|
||||
)
|
||||
@@ -127,7 +136,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile()
|
||||
self.shark_model = self.compile(compressed=args.compressed)
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
@@ -150,7 +159,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu" if self.device == "cpu" else "cuda:0"
|
||||
kwargs["device_map"] = "cpu"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
@@ -158,7 +167,9 @@ class ShardedFalcon(SharkLLMBase):
|
||||
falcon_model = falcon_model.to(torch.float32)
|
||||
return falcon_model
|
||||
|
||||
def compile_layer(self, layer, falconCompileInput, layer_id):
|
||||
def compile_layer(
|
||||
self, layer, falconCompileInput, layer_id, device_idx=None
|
||||
):
|
||||
self.falcon_mlir_path = Path(
|
||||
f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir"
|
||||
)
|
||||
@@ -177,10 +188,13 @@ class ShardedFalcon(SharkLLMBase):
|
||||
single_file=True,
|
||||
)
|
||||
vmfb = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path, self.device, "linalg"
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=device_idx,
|
||||
)
|
||||
if vmfb is not None:
|
||||
return vmfb
|
||||
return vmfb, device_idx
|
||||
|
||||
print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}")
|
||||
if self.falcon_mlir_path.exists():
|
||||
@@ -215,7 +229,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
f16_input_mask = [False]
|
||||
elif layer_id in ["ln_f", "lm_head"]:
|
||||
f16_input_mask = [True]
|
||||
elif type(layer_id) == int:
|
||||
elif "_" in layer_id or type(layer_id) == int:
|
||||
f16_input_mask = [True, False]
|
||||
else:
|
||||
raise ValueError("Unsupported layer: ", layer_id)
|
||||
@@ -257,6 +271,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
mlir_module=self.falcon_mlir_path,
|
||||
device=self.device,
|
||||
mlir_dialect="linalg",
|
||||
device_idx=device_idx,
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.falcon_vmfb_path.parent.absolute(),
|
||||
@@ -276,27 +291,43 @@ class ShardedFalcon(SharkLLMBase):
|
||||
print("Saved falcon vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
return shark_module
|
||||
return shark_module, device_idx
|
||||
|
||||
def compile(self):
|
||||
def compile(self, compressed=False):
|
||||
sample_input_ids = torch.zeros([100], dtype=torch.int64)
|
||||
sample_attention_mask = torch.zeros(
|
||||
[1, 1, 100, 100], dtype=torch.float32
|
||||
)
|
||||
num_group_layers = 1
|
||||
if "7b" in self.model_name:
|
||||
num_in_features = 4544
|
||||
if compressed:
|
||||
num_group_layers = 8
|
||||
else:
|
||||
num_in_features = 14848
|
||||
sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
|
||||
if compressed:
|
||||
num_group_layers = 20
|
||||
|
||||
sample_hidden_states = torch.zeros(
|
||||
[1, 100, num_in_features], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Determine number of available devices
|
||||
num_devices = 1
|
||||
if self.device == "rocm":
|
||||
import iree.runtime as ireert
|
||||
|
||||
haldriver = ireert.get_driver(self.device)
|
||||
num_devices = len(haldriver.query_available_devices())
|
||||
|
||||
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
|
||||
print("Compiling Layer lm_head")
|
||||
shark_lm_head = self.compile_layer(
|
||||
lm_head, [sample_hidden_states], "lm_head"
|
||||
shark_lm_head, _ = self.compile_layer(
|
||||
lm_head,
|
||||
[sample_hidden_states],
|
||||
"lm_head",
|
||||
device_idx=0 % num_devices if self.device == "rocm" else None,
|
||||
)
|
||||
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
|
||||
|
||||
@@ -304,8 +335,11 @@ class ShardedFalcon(SharkLLMBase):
|
||||
self.src_model.transformer.word_embeddings
|
||||
)
|
||||
print("Compiling Layer word_embeddings")
|
||||
shark_word_embedding = self.compile_layer(
|
||||
word_embedding, [sample_input_ids], "word_embeddings"
|
||||
shark_word_embedding, _ = self.compile_layer(
|
||||
word_embedding,
|
||||
[sample_input_ids],
|
||||
"word_embeddings",
|
||||
device_idx=1 % num_devices if self.device == "rocm" else None,
|
||||
)
|
||||
shark_word_embedding = CompiledWordEmbeddingsLayer(
|
||||
shark_word_embedding
|
||||
@@ -313,20 +347,56 @@ class ShardedFalcon(SharkLLMBase):
|
||||
|
||||
ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f)
|
||||
print("Compiling Layer ln_f")
|
||||
shark_ln_f = self.compile_layer(ln_f, [sample_hidden_states], "ln_f")
|
||||
shark_ln_f, _ = self.compile_layer(
|
||||
ln_f,
|
||||
[sample_hidden_states],
|
||||
"ln_f",
|
||||
device_idx=2 % num_devices if self.device == "rocm" else None,
|
||||
)
|
||||
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
|
||||
|
||||
shark_layers = []
|
||||
for i in range(len(self.src_model.transformer.h)):
|
||||
print("Compiling Layer {}".format(i))
|
||||
layer_i = self.src_model.transformer.h[i]
|
||||
pytorch_layer_i = DecoderLayer(layer_i)
|
||||
shark_module = self.compile_layer(
|
||||
for i in range(
|
||||
int(len(self.src_model.transformer.h) / num_group_layers)
|
||||
):
|
||||
device_idx = i % num_devices if self.device == "rocm" else None
|
||||
layer_id = i
|
||||
pytorch_class = DecoderLayer
|
||||
compiled_class = CompiledDecoderLayer
|
||||
if compressed:
|
||||
layer_id = (
|
||||
str(i * num_group_layers)
|
||||
+ "_"
|
||||
+ str((i + 1) * num_group_layers)
|
||||
)
|
||||
pytorch_class = EightDecoderLayer
|
||||
compiled_class = CompiledEightDecoderLayer
|
||||
|
||||
print("Compiling Layer {}".format(layer_id))
|
||||
if compressed:
|
||||
layer_i = self.src_model.transformer.h[
|
||||
i * num_group_layers : (i + 1) * num_group_layers
|
||||
]
|
||||
else:
|
||||
layer_i = self.src_model.transformer.h[i]
|
||||
|
||||
pytorch_layer_i = pytorch_class(
|
||||
layer_i, args.falcon_variant_to_use
|
||||
)
|
||||
shark_module, device_idx = self.compile_layer(
|
||||
pytorch_layer_i,
|
||||
[sample_hidden_states, sample_attention_mask],
|
||||
i,
|
||||
layer_id,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
del shark_module
|
||||
shark_layer_i = compiled_class(
|
||||
layer_id,
|
||||
device_idx,
|
||||
args.falcon_variant_to_use,
|
||||
self.device,
|
||||
self.precision,
|
||||
)
|
||||
shark_layer_i = CompiledDecoderLayer(shark_module)
|
||||
shark_layers.append(shark_layer_i)
|
||||
|
||||
sharded_model = ShardedFalconModel(
|
||||
@@ -355,9 +425,6 @@ class ShardedFalcon(SharkLLMBase):
|
||||
if input_ids.shape[1] == 0:
|
||||
input_ids = None
|
||||
attention_mask = None
|
||||
in_b = 1
|
||||
else:
|
||||
in_b = input_ids.shape[0]
|
||||
|
||||
generate_kwargs = {
|
||||
"max_length": self.max_num_tokens,
|
||||
@@ -387,7 +454,6 @@ class ShardedFalcon(SharkLLMBase):
|
||||
) = self.src_model._prepare_model_inputs(
|
||||
None, generation_config.bos_token_id, model_kwargs
|
||||
)
|
||||
batch_size = inputs_tensor.shape[0]
|
||||
|
||||
model_kwargs["output_attentions"] = generation_config.output_attentions
|
||||
model_kwargs[
|
||||
@@ -438,8 +504,6 @@ class ShardedFalcon(SharkLLMBase):
|
||||
self.eos_token_id = eos_token_id
|
||||
|
||||
output_scores = generation_config.output_scores # False
|
||||
output_attentions = generation_config.output_attentions # False
|
||||
output_hidden_states = generation_config.output_hidden_states # False
|
||||
return_dict_in_generate = (
|
||||
generation_config.return_dict_in_generate # False
|
||||
)
|
||||
@@ -448,15 +512,6 @@ class ShardedFalcon(SharkLLMBase):
|
||||
self.scores = (
|
||||
() if (return_dict_in_generate and output_scores) else None
|
||||
)
|
||||
decoder_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
cross_attentions = (
|
||||
() if (return_dict_in_generate and output_attentions) else None
|
||||
)
|
||||
decoder_hidden_states = (
|
||||
() if (return_dict_in_generate and output_hidden_states) else None
|
||||
)
|
||||
|
||||
# keep track of which sequences are already finished
|
||||
self.unfinished_sequences = torch.ones(
|
||||
@@ -465,7 +520,11 @@ class ShardedFalcon(SharkLLMBase):
|
||||
|
||||
all_text = prompt
|
||||
|
||||
start = time.time()
|
||||
count = 0
|
||||
for i in range(self.max_num_tokens - 1):
|
||||
count = count + 1
|
||||
|
||||
next_token = self.generate_new_token()
|
||||
new_word = self.tokenizer.decode(
|
||||
next_token.cpu().numpy(),
|
||||
@@ -477,6 +536,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
all_text = all_text + new_word
|
||||
|
||||
print(f"{new_word}", end="", flush=True)
|
||||
print(f"{all_text}", end="", flush=True)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if self.eos_token_id_tensor is not None:
|
||||
@@ -492,6 +552,13 @@ class ShardedFalcon(SharkLLMBase):
|
||||
):
|
||||
break
|
||||
|
||||
end = time.time()
|
||||
print(
|
||||
"\n\nTime taken is {:.2f} seconds/token\n".format(
|
||||
(end - start) / count
|
||||
)
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@@ -1023,8 +1090,6 @@ if __name__ == "__main__":
|
||||
precision=args.precision,
|
||||
)
|
||||
|
||||
import gc
|
||||
|
||||
default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
|
||||
continue_execution = True
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ datas += collect_data_files("jsonschema_specifications")
|
||||
datas += collect_data_files("cpuinfo")
|
||||
datas += collect_data_files("langchain")
|
||||
datas += collect_data_files("cv2")
|
||||
datas += collect_data_files("einops")
|
||||
datas += [
|
||||
("src/utils/resources/prompts.json", "resources"),
|
||||
("src/utils/resources/model_db.json", "resources"),
|
||||
|
||||
@@ -11,12 +11,12 @@
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,7 @@
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -37,7 +37,7 @@
|
||||
"specified_compilation_flags": {
|
||||
"cuda": [],
|
||||
"default_device": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -45,12 +45,12 @@
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,8 +203,8 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
if use_winograd:
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
|
||||
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
|
||||
"iree-linalg-ext-convert-conv2d-to-winograd))"
|
||||
@@ -212,8 +212,8 @@ def dump_after_mlir(input_mlir, use_winograd):
|
||||
else:
|
||||
preprocess_flag = (
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module"
|
||||
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
|
||||
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"(func.func(iree-global-opt-detach-elementwise-from-named-ops,"
|
||||
"iree-global-opt-convert-1x1-filter-conv2d-to-matmul,"
|
||||
"iree-preprocessing-convert-conv2d-to-img2col,"
|
||||
"iree-preprocessing-pad-linalg-ops{pad-size=32}))"
|
||||
)
|
||||
|
||||
@@ -422,7 +422,7 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--use_stencil",
|
||||
choices=["canny", "openpose", "scribble"],
|
||||
choices=["canny", "openpose", "scribble", "zoedepth"],
|
||||
help="Enable the stencil feature.",
|
||||
)
|
||||
|
||||
@@ -725,6 +725,17 @@ p.add_argument(
|
||||
help="Specifies whether the docuchat's web version is running or not.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# rocm Flags
|
||||
##############################################################################
|
||||
|
||||
p.add_argument(
|
||||
"--iree_rocm_target_chip",
|
||||
type=str,
|
||||
default="gfx1100",
|
||||
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Default gfx1100",
|
||||
)
|
||||
|
||||
args, unknown = p.parse_known_args()
|
||||
if args.import_debug:
|
||||
os.environ["IREE_SAVE_TEMPS"] = os.path.join(
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector
|
||||
from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector
|
||||
from apps.stable_diffusion.src.utils.stencils.zoe import ZoeDetector
|
||||
|
||||
@@ -4,6 +4,7 @@ import torch
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
|
||||
stencil = {}
|
||||
@@ -117,6 +118,9 @@ def controlnet_hint_conversion(
|
||||
case "scribble":
|
||||
print("Working with scribble")
|
||||
controlnet_hint = hint_scribble(image)
|
||||
case "zoedepth":
|
||||
print("Working with ZoeDepth")
|
||||
controlnet_hint = hint_zoedepth(image)
|
||||
case _:
|
||||
return None
|
||||
controlnet_hint = controlnet_hint_shaping(
|
||||
@@ -127,7 +131,7 @@ def controlnet_hint_conversion(
|
||||
|
||||
stencil_to_model_id_map = {
|
||||
"canny": "lllyasviel/control_v11p_sd15_canny",
|
||||
"depth": "lllyasviel/control_v11p_sd15_depth",
|
||||
"zoedepth": "lllyasviel/control_v11f1p_sd15_depth",
|
||||
"hed": "lllyasviel/sd-controlnet-hed",
|
||||
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
|
||||
"normal": "lllyasviel/control_v11p_sd15_normalbae",
|
||||
@@ -184,3 +188,16 @@ def hint_scribble(image: Image.Image):
|
||||
detected_map = np.zeros_like(input_image, dtype=np.uint8)
|
||||
detected_map[np.min(input_image, axis=2) < 127] = 255
|
||||
return detected_map
|
||||
|
||||
|
||||
# Stencil 4. Depth (Only Zoe Preprocessing)
|
||||
def hint_zoedepth(image: Image.Image):
|
||||
with torch.no_grad():
|
||||
input_image = np.array(image)
|
||||
|
||||
if not "depth" in stencil:
|
||||
stencil["depth"] = ZoeDetector()
|
||||
|
||||
detected_map = stencil["depth"](input_image)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
58
apps/stable_diffusion/src/utils/stencils/zoe/__init__.py
Normal file
58
apps/stable_diffusion/src/utils/stencils/zoe/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import requests
|
||||
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
remote_model_path = (
|
||||
"https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt"
|
||||
)
|
||||
|
||||
|
||||
class ZoeDetector:
|
||||
def __init__(self):
|
||||
cwd = Path.cwd()
|
||||
ckpt_path = Path(cwd, "stencil_annotator")
|
||||
ckpt_path.mkdir(parents=True, exist_ok=True)
|
||||
modelpath = ckpt_path / "ZoeD_M12_N.pt"
|
||||
|
||||
with requests.get(remote_model_path, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
with open(modelpath, "wb") as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
model = torch.hub.load(
|
||||
"monorimet/ZoeDepth:torch_update",
|
||||
"ZoeD_N",
|
||||
pretrained=False,
|
||||
force_reload=False,
|
||||
)
|
||||
model.load_state_dict(
|
||||
torch.load(modelpath, map_location=model.device)["model"]
|
||||
)
|
||||
model.eval()
|
||||
self.model = model
|
||||
|
||||
def __call__(self, input_image):
|
||||
assert input_image.ndim == 3
|
||||
image_depth = input_image
|
||||
with torch.no_grad():
|
||||
image_depth = torch.from_numpy(image_depth).float()
|
||||
image_depth = image_depth / 255.0
|
||||
image_depth = rearrange(image_depth, "h w c -> 1 c h w")
|
||||
depth = self.model.infer(image_depth)
|
||||
|
||||
depth = depth[0, 0].cpu().numpy()
|
||||
|
||||
vmin = np.percentile(depth, 2)
|
||||
vmax = np.percentile(depth, 85)
|
||||
|
||||
depth -= vmin
|
||||
depth /= vmax - vmin
|
||||
depth = 1.0 - depth
|
||||
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
|
||||
|
||||
return depth_image
|
||||
@@ -895,6 +895,13 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
pngInfo = PngImagePlugin.PngInfo()
|
||||
|
||||
if args.write_metadata_to_png:
|
||||
# Using a conditional expression caused problems, so setting a new
|
||||
# variable for now.
|
||||
if args.use_hiresfix:
|
||||
png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}"
|
||||
else:
|
||||
png_size_text = f"{args.width}x{args.height}"
|
||||
|
||||
pngInfo.add_text(
|
||||
"parameters",
|
||||
f"{args.prompts[0]}"
|
||||
@@ -903,7 +910,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
f"Sampler: {args.scheduler}, "
|
||||
f"CFG scale: {args.guidance_scale}, "
|
||||
f"Seed: {img_seed},"
|
||||
f"Size: {args.width}x{args.height}, "
|
||||
f"Size: {png_size_text}, "
|
||||
f"Model: {img_model}, "
|
||||
f"VAE: {img_vae}, "
|
||||
f"LoRA: {img_lora}",
|
||||
@@ -930,8 +937,10 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
"CFG_SCALE": args.guidance_scale,
|
||||
"PRECISION": args.precision,
|
||||
"STEPS": args.steps,
|
||||
"HEIGHT": args.height,
|
||||
"WIDTH": args.width,
|
||||
"HEIGHT": args.height
|
||||
if not args.use_hiresfix
|
||||
else args.hiresfix_height,
|
||||
"WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width,
|
||||
"MAX_LENGTH": args.max_length,
|
||||
"OUTPUT": out_img_path,
|
||||
"VAE": img_vae,
|
||||
@@ -969,6 +978,10 @@ def get_generation_text_info(seeds, device):
|
||||
)
|
||||
text_output += (
|
||||
f"\nsize={args.height}x{args.width}, "
|
||||
if not args.use_hiresfix
|
||||
else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, "
|
||||
)
|
||||
text_output += (
|
||||
f"batch_count={args.batch_count}, "
|
||||
f"batch_size={args.batch_size}, "
|
||||
f"max_length={args.max_length}"
|
||||
|
||||
@@ -435,8 +435,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
elem_id="stencil_model",
|
||||
label="Stencil model",
|
||||
value="None",
|
||||
choices=["None", "canny", "openpose", "scribble"],
|
||||
allow_custom_value=True,
|
||||
choices=[
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
],
|
||||
)
|
||||
|
||||
def show_canvas(choice):
|
||||
@@ -638,16 +643,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -667,6 +662,18 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
show_label=False,
|
||||
)
|
||||
img2img_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
img2img_sendto_outpaint = gr.Button(
|
||||
|
||||
@@ -514,16 +514,6 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -543,7 +533,18 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
show_label=False,
|
||||
)
|
||||
inpaint_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
inpaint_sendto_outpaint = gr.Button(
|
||||
|
||||
@@ -540,16 +540,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -569,6 +559,18 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
show_label=False,
|
||||
)
|
||||
outpaint_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
|
||||
@@ -32,36 +32,39 @@ model_map = {
|
||||
# NOTE: Each `model_name` should have its own start message
|
||||
start_message = {
|
||||
"llama2_7b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_13b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"llama2_70b": (
|
||||
"System: You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
|
||||
"content. Please ensure that your responses are socially unbiased and positive "
|
||||
"in nature. If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. If you don't know the "
|
||||
"answer to a question, please don't share false information."
|
||||
"You are a helpful, respectful and honest assistant. Always answer "
|
||||
"as helpfully as possible, while being safe. Your answers should not "
|
||||
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
|
||||
"illegal content. Please ensure that your responses are socially "
|
||||
"unbiased and positive in nature. If a question does not make any "
|
||||
"sense, or is not factually coherent, explain why instead of "
|
||||
"answering something not correct. If you don't know the answer "
|
||||
"to a question, please don't share false information."
|
||||
),
|
||||
"vicuna": (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's "
|
||||
"questions.\n"
|
||||
"A chat between a curious user and an artificial intelligence "
|
||||
"assistant. The assistant gives helpful, detailed, and "
|
||||
"polite answers to the user's questions.\n"
|
||||
),
|
||||
}
|
||||
|
||||
@@ -77,7 +80,10 @@ def create_prompt(model_name, history, prompt_prefix):
|
||||
conversation = "".join(
|
||||
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
|
||||
)
|
||||
msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
if prompt_prefix:
|
||||
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
else:
|
||||
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
|
||||
elif model_name in ["vicuna"]:
|
||||
conversation = "".join(
|
||||
[
|
||||
@@ -210,8 +216,14 @@ def chat(
|
||||
assert (
|
||||
device_id
|
||||
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
|
||||
print(f"Will use vulkan target triple : {vulkan_target_triple}")
|
||||
|
||||
print(f"Will use target triple : {vulkan_target_triple}")
|
||||
elif "rocm" in device:
|
||||
# add iree rocm flags
|
||||
_extra_args.append(
|
||||
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
|
||||
)
|
||||
print(f"extra args = {_extra_args}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
vicuna_model = ShardedVicuna(
|
||||
|
||||
@@ -140,6 +140,11 @@ def txt2img_inf(
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.use_hiresfix = use_hiresfix
|
||||
args.hiresfix_height = hiresfix_height
|
||||
args.hiresfix_width = hiresfix_width
|
||||
args.hiresfix_strength = hiresfix_strength
|
||||
args.resample_type = resample_type
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.iree_metal_target_platform = init_iree_metal_target_platform
|
||||
|
||||
@@ -533,16 +533,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Row():
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
@@ -562,7 +552,18 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
show_label=False,
|
||||
)
|
||||
upscaler_status = gr.Textbox(visible=False)
|
||||
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
blank_thing_for_row = None
|
||||
with gr.Row():
|
||||
upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img")
|
||||
upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint")
|
||||
|
||||
@@ -129,12 +129,12 @@ pytest_benchmark_param = pytest.mark.parametrize(
|
||||
pytest.param(True, "cpu", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"gpu",
|
||||
"cuda",
|
||||
marks=pytest.mark.skipif(
|
||||
check_device_drivers("gpu"), reason="nvidia-smi not found"
|
||||
check_device_drivers("cuda"), reason="nvidia-smi not found"
|
||||
),
|
||||
),
|
||||
pytest.param(True, "gpu", marks=pytest.mark.skip),
|
||||
pytest.param(True, "cuda", marks=pytest.mark.skip),
|
||||
pytest.param(
|
||||
False,
|
||||
"vulkan",
|
||||
|
||||
@@ -41,6 +41,7 @@ tiktoken # for codegen
|
||||
joblib # for langchain
|
||||
timm # for MiniGPT4
|
||||
langchain
|
||||
einops # for zoedepth
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
||||
@@ -95,38 +95,31 @@ _IREE_TARGET_MAP = {
|
||||
# Finds whether the required drivers are installed for the given device.
|
||||
@functools.cache
|
||||
def check_device_drivers(device):
|
||||
"""Checks necessary drivers present for gpu and vulkan devices"""
|
||||
"""
|
||||
Checks necessary drivers present for gpu and vulkan devices
|
||||
False => drivers present!
|
||||
"""
|
||||
if "://" in device:
|
||||
device = device.split("://")[0]
|
||||
|
||||
if device == "cuda":
|
||||
try:
|
||||
subprocess.check_output("nvidia-smi")
|
||||
except Exception:
|
||||
return True
|
||||
elif device in ["vulkan"]:
|
||||
try:
|
||||
subprocess.check_output("vulkaninfo")
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "metal":
|
||||
return False
|
||||
elif device in ["intel-gpu"]:
|
||||
try:
|
||||
subprocess.check_output(["dpkg", "-L", "intel-level-zero-gpu"])
|
||||
return False
|
||||
except Exception:
|
||||
return True
|
||||
elif device == "cpu":
|
||||
return False
|
||||
elif device == "rocm":
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
subprocess.check_output("hipinfo")
|
||||
else:
|
||||
subprocess.check_output("rocminfo")
|
||||
except Exception:
|
||||
return True
|
||||
from iree.runtime import get_driver
|
||||
|
||||
device_mapped = iree_device_map(device)
|
||||
|
||||
try:
|
||||
_ = get_driver(device_mapped)
|
||||
except ValueError as ve:
|
||||
print(
|
||||
f"[ERR] device `{device}` not registered with IREE. "
|
||||
"Ensure IREE is configured for use with this device.\n"
|
||||
f"Full Error: \n {repr(ve)}"
|
||||
)
|
||||
return True
|
||||
except RuntimeError as re:
|
||||
print(
|
||||
f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}"
|
||||
)
|
||||
return True
|
||||
|
||||
# Unknown device. We assume drivers are installed.
|
||||
return False
|
||||
@@ -134,11 +127,32 @@ def check_device_drivers(device):
|
||||
|
||||
# Installation info for the missing device drivers.
|
||||
def device_driver_info(device):
|
||||
if device == "cuda":
|
||||
return "nvidia-smi not found, please install the required drivers from https://www.nvidia.in/Download/index.aspx?lang=en-in"
|
||||
elif device in ["metal", "vulkan"]:
|
||||
return "vulkaninfo not found, Install from https://vulkan.lunarg.com/sdk/home or your distribution"
|
||||
elif device == "rocm":
|
||||
return "rocm info not found. Please install rocm"
|
||||
device_driver_err_map = {
|
||||
"cuda": {
|
||||
"debug": "Try `nvidia-smi` on system to check.",
|
||||
"solution": " from https://www.nvidia.in/Download/index.aspx?lang=en-in for your system.",
|
||||
},
|
||||
"vulkan": {
|
||||
"debug": "Try `vulkaninfo` on system to check.",
|
||||
"solution": " from https://vulkan.lunarg.com/sdk/home for your distribution.",
|
||||
},
|
||||
"metal": {
|
||||
"debug": "Check if Bare metal is supported and enabled on your system.",
|
||||
"solution": ".",
|
||||
},
|
||||
"rocm": {
|
||||
"debug": f"Try `{'hip' if sys.platform == 'win32' else 'rocm'}info` on system to check.",
|
||||
"solution": " from https://rocm.docs.amd.com/en/latest/rocm.html for your system.",
|
||||
},
|
||||
}
|
||||
|
||||
if device in device_driver_err_map:
|
||||
err_msg = (
|
||||
f"Required drivers for {device} not found. {device_driver_err_map[device]['debug']} "
|
||||
f"Please install the required drivers{device_driver_err_map[device]['solution']} "
|
||||
f"For further assistance please reach out to the community on discord [https://discord.com/invite/RUqY2h2s9u]"
|
||||
f" and/or file a bug at https://github.com/nod-ai/SHARK/issues"
|
||||
)
|
||||
return err_msg
|
||||
else:
|
||||
return f"{device} is not supported."
|
||||
|
||||
@@ -75,7 +75,7 @@ def get_iree_device_args(device, extra_args=[]):
|
||||
if device_uri[0] == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args()
|
||||
return get_iree_rocm_args(extra_args=extra_args)
|
||||
return []
|
||||
|
||||
|
||||
@@ -392,6 +392,9 @@ def load_vmfb_using_mmap(
|
||||
)
|
||||
dl.log(f"ireert.create_device()")
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
dl.log(f"ireert.Config()")
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
@@ -574,10 +577,17 @@ def get_results(
|
||||
frontend="torch",
|
||||
send_to_host=True,
|
||||
debug_timeout: float = 5.0,
|
||||
device: str = None,
|
||||
):
|
||||
"""Runs a .vmfb file given inputs and config and returns output."""
|
||||
with DetailLogger(debug_timeout) as dl:
|
||||
device_inputs = []
|
||||
if device == "rocm":
|
||||
haldriver = ireert.get_driver("rocm")
|
||||
haldevice = haldriver.create_device(
|
||||
config.id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
for input_array in input:
|
||||
dl.log(f"Load to device: {input_array.shape}")
|
||||
device_inputs.append(
|
||||
|
||||
@@ -40,55 +40,26 @@ def get_iree_gpu_args():
|
||||
|
||||
|
||||
# Get the default gpu args given the architecture.
|
||||
@functools.cache
|
||||
def get_iree_rocm_args():
|
||||
def get_iree_rocm_args(extra_args=[]):
|
||||
ireert.flags.FUNCTION_INPUT_VALIDATION = False
|
||||
# get arch from hipinfo.
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
rocm_flags = ["--iree-rocm-link-bc=true"]
|
||||
|
||||
if sys.platform == "win32":
|
||||
if "HIP_PATH" in os.environ:
|
||||
rocm_path = os.environ["HIP_PATH"]
|
||||
print(f"Found a ROCm installation at {rocm_path}.")
|
||||
else:
|
||||
print("Failed to find ROCM_PATH. Defaulting to C:\\AMD\\ROCM\\5.5")
|
||||
rocm_path = "C:\\AMD\\ROCM\\5.5"
|
||||
else:
|
||||
if "ROCM_PATH" in os.environ:
|
||||
rocm_path = os.environ["ROCM_PATH"]
|
||||
print(f"Found a ROCm installation at {rocm_path}.")
|
||||
else:
|
||||
print("Failed to find ROCM_PATH. Defaulting to /opt/rocm")
|
||||
rocm_path = "/opt/rocm/"
|
||||
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
rocm_arch = re.search(
|
||||
r"gfx\d{3,}",
|
||||
subprocess.check_output("hipinfo", shell=True, text=True),
|
||||
).group(0)
|
||||
else:
|
||||
rocm_arch = re.match(
|
||||
r".*(gfx\w+)",
|
||||
subprocess.check_output(
|
||||
"rocminfo | grep -i 'gfx'", shell=True, text=True
|
||||
),
|
||||
).group(1)
|
||||
print(f"Found rocm arch {rocm_arch}...")
|
||||
except:
|
||||
# Add the target arch flag for rocm device
|
||||
flag_present = False
|
||||
for flag in extra_args:
|
||||
if "iree-rocm-target-chip" in flag:
|
||||
flag_present = True
|
||||
print(
|
||||
f"found rocm target device arch from flag : {flag.split('=')[1]}"
|
||||
)
|
||||
if not flag_present:
|
||||
print(
|
||||
"Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
|
||||
)
|
||||
rocm_arch = "gfx1100"
|
||||
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
|
||||
|
||||
bc_path = os.path.join(rocm_path, "amdgcn", "bitcode")
|
||||
return [
|
||||
f"--iree-rocm-target-chip={rocm_arch}",
|
||||
"--iree-rocm-link-bc=true",
|
||||
f"--iree-rocm-bc-dir={bc_path}",
|
||||
]
|
||||
return rocm_flags
|
||||
|
||||
|
||||
# Some constants taken from cuda.h
|
||||
|
||||
@@ -68,6 +68,8 @@ def get_vulkan_target_triple(device_name):
|
||||
Returns:
|
||||
str or None: target triple or None if no match found for given name
|
||||
"""
|
||||
|
||||
# TODO: Replace this with a dict or something smarter.
|
||||
system_os = get_os_name()
|
||||
# Apple Targets
|
||||
if all(x in device_name for x in ("Apple", "M1")):
|
||||
@@ -117,6 +119,8 @@ def get_vulkan_target_triple(device_name):
|
||||
# Amd Targets
|
||||
# Linux: Radeon RX 7900 XTX
|
||||
# Windows: AMD Radeon RX 7900 XTX
|
||||
elif all(x in device_name for x in ("RX", "7800")):
|
||||
triple = f"rdna3-7800-{system_os}"
|
||||
elif all(x in device_name for x in ("RX", "7900")):
|
||||
triple = f"rdna3-7900-{system_os}"
|
||||
elif all(x in device_name for x in ("Radeon", "780M")):
|
||||
|
||||
@@ -150,11 +150,15 @@ class SharkInference:
|
||||
|
||||
# inputs are considered to be tuple of np.array.
|
||||
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run(function_name, inputs, send_to_host)
|
||||
return self.shark_runner.run(
|
||||
function_name, inputs, send_to_host, device=self.device
|
||||
)
|
||||
|
||||
# forward function.
|
||||
def forward(self, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run("forward", inputs, send_to_host)
|
||||
return self.shark_runner.run(
|
||||
"forward", inputs, send_to_host, device=self.device
|
||||
)
|
||||
|
||||
# Get all function names defined within the compiled module.
|
||||
def get_functions_in_module(self):
|
||||
|
||||
@@ -109,7 +109,9 @@ class SharkRunner:
|
||||
self.temp_file_to_unlink = params["temp_file_to_unlink"]
|
||||
del params
|
||||
|
||||
def run(self, function_name, inputs: tuple, send_to_host=False):
|
||||
def run(
|
||||
self, function_name, inputs: tuple, send_to_host=False, device=None
|
||||
):
|
||||
return get_results(
|
||||
self.iree_compilation_module,
|
||||
function_name,
|
||||
@@ -117,6 +119,7 @@ class SharkRunner:
|
||||
self.iree_config,
|
||||
self.mlir_dialect,
|
||||
send_to_host,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Get all function names defined within the compiled module.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
from shark.parser import shark_args
|
||||
|
||||
@@ -1,30 +1,25 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from shark_opt_wrapper import OPTForCausalLMModel
|
||||
from shark.iree_utils._common import (
|
||||
check_device_drivers,
|
||||
device_driver_info,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
OPT_MODEL = "opt-1.3b"
|
||||
OPT_FS_NAME = "opt-1_3b"
|
||||
MAX_SEQUENCE_LENGTH = 128
|
||||
MAX_NEW_TOKENS = 60
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
def create_module(model_name, tokenizer, device):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained("facebook/" + model_name)
|
||||
def create_module(model_name, tokenizer, device, args):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained(
|
||||
model_name, allow_mismatched_sizes=True
|
||||
)
|
||||
opt_base_model.eval()
|
||||
opt_model = OPTForCausalLMModel(opt_base_model)
|
||||
encoded_inputs = tokenizer(
|
||||
"What is the meaning of life?",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=args.max_seq_len,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = (
|
||||
@@ -33,8 +28,11 @@ def create_module(model_name, tokenizer, device):
|
||||
)
|
||||
# np.save("model_inputs_0.npy", inputs[0])
|
||||
# np.save("model_inputs_1.npy", inputs[1])
|
||||
opt_fs_name = "-".join(
|
||||
"_".join(args.model_name.split("/")[1].split("-")).split(".")
|
||||
)
|
||||
|
||||
mlir_path = f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch.mlir"
|
||||
mlir_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch.mlir"
|
||||
if os.path.isfile(mlir_path):
|
||||
print(f"Found .mlir from {mlir_path}")
|
||||
else:
|
||||
@@ -42,7 +40,7 @@ def create_module(model_name, tokenizer, device):
|
||||
model=opt_model,
|
||||
inputs=inputs,
|
||||
is_f16=False,
|
||||
model_name=OPT_FS_NAME,
|
||||
model_name=opt_fs_name,
|
||||
return_str=True,
|
||||
)
|
||||
with open(mlir_path, "w") as f:
|
||||
@@ -57,7 +55,7 @@ def create_module(model_name, tokenizer, device):
|
||||
is_benchmark=False,
|
||||
)
|
||||
|
||||
vmfb_name = f"{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_{device}"
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu"
|
||||
shark_module.save_module(module_name=vmfb_name, debug=False)
|
||||
vmfb_path = vmfb_name + ".vmfb"
|
||||
return vmfb_path
|
||||
@@ -71,11 +69,11 @@ def shouldStop(tokens):
|
||||
return False
|
||||
|
||||
|
||||
def generate_new_token(shark_model, tokenizer, new_text):
|
||||
def generate_new_token(shark_module, tokenizer, new_text, max_seq_len: int):
|
||||
model_inputs = tokenizer(
|
||||
new_text,
|
||||
padding="max_length",
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
max_length=max_seq_len,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -84,7 +82,7 @@ def generate_new_token(shark_model, tokenizer, new_text):
|
||||
model_inputs["attention_mask"],
|
||||
)
|
||||
sum_attentionmask = torch.sum(model_inputs.attention_mask)
|
||||
output = shark_model("forward", inputs)
|
||||
output = shark_module("forward", inputs)
|
||||
output = torch.FloatTensor(output[0])
|
||||
next_toks = torch.topk(output, 1)
|
||||
stop_generation = False
|
||||
@@ -104,39 +102,96 @@ def generate_new_token(shark_model, tokenizer, new_text):
|
||||
return ret_dict
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--max-seq-len", type=int, default=32)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
help="Model name",
|
||||
type=str,
|
||||
choices=[
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-350m",
|
||||
"facebook/opt-1.3b",
|
||||
"facebook/opt-6.7b",
|
||||
"mit-han-lab/opt-125m-smoothquant",
|
||||
"mit-han-lab/opt-1.3b-smoothquant",
|
||||
"mit-han-lab/opt-2.7b-smoothquant",
|
||||
"mit-han-lab/opt-6.7b-smoothquant",
|
||||
"mit-han-lab/opt-13b-smoothquant",
|
||||
],
|
||||
default="facebook/opt-1.3b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recompile",
|
||||
help="If set, recompiles MLIR -> .vmfb",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin-path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("args={}".format(args))
|
||||
return args
|
||||
|
||||
|
||||
def generate_tokens(
|
||||
opt_shark_module: "SharkInference",
|
||||
tokenizer,
|
||||
input_text: str,
|
||||
max_output_len: int,
|
||||
print_intermediate_results: True,
|
||||
) -> Iterable[str]:
|
||||
words_list = []
|
||||
new_text = input_text
|
||||
try:
|
||||
for _ in range(max_output_len):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text, max_output_len
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
if generated_token_op["stop_generation"]:
|
||||
break
|
||||
if print_intermediate_results:
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text += detok
|
||||
except KeyboardInterrupt as e:
|
||||
print("Exiting token generation.")
|
||||
return words_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"facebook/" + OPT_MODEL, use_fast=False
|
||||
args = parse_args()
|
||||
if "smoothquant" in args.model_name:
|
||||
token_model_name = f"facebook/opt-{args.model_name.split('-')[3]}"
|
||||
else:
|
||||
token_model_name = args.model_name
|
||||
tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
|
||||
opt_fs_name = "-".join(
|
||||
"_".join(args.model_name.split("/")[1].split("-")).split(".")
|
||||
)
|
||||
vmfb_path = (
|
||||
f"./{OPT_FS_NAME}_causallm_{MAX_SEQUENCE_LENGTH}_torch_cpu-task.vmfb"
|
||||
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
|
||||
if args.plugin_path is not None:
|
||||
rt_flags = [f"--executable_plugin={args.plugin_path}"]
|
||||
else:
|
||||
rt_flags = []
|
||||
opt_shark_module = SharkInference(
|
||||
mlir_module=None, device="cpu-task", rt_flags=rt_flags
|
||||
)
|
||||
opt_shark_module = SharkInference(mlir_module=None, device="cpu-task")
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
vmfb_path = create_module(OPT_MODEL, tokenizer, "cpu-task")
|
||||
vmfb_path = create_module(args.model_name, tokenizer, "cpu-task", args)
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
while True:
|
||||
try:
|
||||
new_text = input("Give me a sentence to complete:")
|
||||
new_text_init = new_text
|
||||
words_list = []
|
||||
|
||||
for i in range(MAX_NEW_TOKENS):
|
||||
generated_token_op = generate_new_token(
|
||||
opt_shark_module, tokenizer, new_text
|
||||
)
|
||||
detok = generated_token_op["detok"]
|
||||
stop_generation = generated_token_op["stop_generation"]
|
||||
if stop_generation:
|
||||
break
|
||||
print(detok, end="", flush=True)
|
||||
words_list.append(detok)
|
||||
if detok == "":
|
||||
break
|
||||
new_text = new_text + detok
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Exiting program.")
|
||||
break
|
||||
input_text = input("Give me a sentence to complete:")
|
||||
generate_tokens(
|
||||
opt_shark_module, tokenizer, input_text, args.max_seq_len
|
||||
)
|
||||
|
||||
74
tank/examples/opt/opt_causallm_samples.py
Normal file
74
tank/examples/opt/opt_causallm_samples.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import opt_causallm
|
||||
import opt_util
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--max-seq-len", type=int, default=32)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
help="Model name",
|
||||
type=str,
|
||||
choices=[
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-350m",
|
||||
"facebook/opt-1.3b",
|
||||
"facebook/opt-6.7b",
|
||||
],
|
||||
default="facebook/opt-1.3b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--recompile",
|
||||
help="If set, recompiles MLIR -> .vmfb",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin-path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("args={}".format(args))
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=False)
|
||||
opt_fs_name = "-".join(
|
||||
"_".join(args.model_name.split("/")[1].split("-")).split(".")
|
||||
)
|
||||
vmfb_path = f"./{opt_fs_name}_causallm_{args.max_seq_len}_torch_cpu.vmfb"
|
||||
if args.plugin_path is not None:
|
||||
rt_flags = [f"--executable_plugin={args.plugin_path}"]
|
||||
else:
|
||||
rt_flags = []
|
||||
opt_shark_module = SharkInference(
|
||||
mlir_module=None, device="cpu-task", rt_flags=rt_flags
|
||||
)
|
||||
if os.path.isfile(vmfb_path):
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
else:
|
||||
vmfb_path = opt_causallm.create_module(
|
||||
args.model_name, tokenizer, "cpu-task", args
|
||||
)
|
||||
opt_shark_module.load_module(vmfb_path)
|
||||
|
||||
for prompt in opt_util.PROMPTS:
|
||||
print("\n\nprompt: {}".format(prompt))
|
||||
response = opt_causallm.generate_tokens(
|
||||
opt_shark_module,
|
||||
tokenizer,
|
||||
prompt,
|
||||
args.max_seq_len,
|
||||
print_intermediate_results=False,
|
||||
)
|
||||
print("reponse: {}".format("".join(response)))
|
||||
@@ -22,6 +22,7 @@ import time
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
|
||||
from opt_util import PROMPTS
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
from transformers import AutoTokenizer, OPTForCausalLM
|
||||
@@ -44,19 +45,6 @@ REPORT_LOAD_VIRTUAL_MEMORY_MB = "load_virtual_MB"
|
||||
REPORT_RUN_PHYSICAL_MEMORY_MB = "run_physical_MB"
|
||||
REPORT_RUN_VIRTUAL_MEMORY_MB = "run_virtual_MB"
|
||||
|
||||
PROMPTS = [
|
||||
"What is the meaning of life?",
|
||||
"Tell me something you don't know.",
|
||||
"What does Xilinx do?",
|
||||
"What is the mass of earth?",
|
||||
"What is a poem?",
|
||||
"What is recursion?",
|
||||
"Tell me a one line joke.",
|
||||
"Who is Gilgamesh?",
|
||||
"Tell me something about cryptocurrency.",
|
||||
"How did it all begin?",
|
||||
]
|
||||
|
||||
ModelWrapper = collections.namedtuple("ModelWrapper", ["model", "tokenizer"])
|
||||
|
||||
|
||||
@@ -72,7 +60,9 @@ def import_mlir_module(
|
||||
device: str,
|
||||
max_seq_len: int,
|
||||
):
|
||||
opt_base_model = OPTForCausalLM.from_pretrained(model_name)
|
||||
opt_base_model = OPTForCausalLM.from_pretrained(
|
||||
model_name, ignore_mismatched_sizes=True
|
||||
)
|
||||
opt_base_model.eval()
|
||||
opt_model = OPTForCausalLMModel(opt_base_model)
|
||||
encoded_inputs = tokenizer(
|
||||
@@ -142,13 +132,14 @@ def create_vmfb_module(
|
||||
|
||||
def load_shark_model(
|
||||
model_name: str,
|
||||
token_model_name: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
plugin_path: str = [],
|
||||
) -> ModelWrapper:
|
||||
opt_fs_name = get_opt_fs_name(model_name)
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}_tiled_ukernels.vmfb"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
|
||||
vmfb_name = f"{opt_fs_name}_causallm_{max_seq_len}_torch_{DEVICE}.vmfb"
|
||||
tokenizer = AutoTokenizer.from_pretrained(token_model_name, use_fast=False)
|
||||
if recompile_shark or not os.path.isfile(vmfb_name):
|
||||
print(f"vmfb not found. compiling and saving to {vmfb_name}")
|
||||
create_vmfb_module(
|
||||
@@ -170,10 +161,12 @@ def run_shark_model(model_wrapper: ModelWrapper, tokens):
|
||||
return model_wrapper.model("forward", tokens)
|
||||
|
||||
|
||||
def load_huggingface_model(model_name: str) -> ModelWrapper:
|
||||
def load_huggingface_model(
|
||||
model_name: str, token_model_name: str
|
||||
) -> ModelWrapper:
|
||||
return ModelWrapper(
|
||||
model=OPTForCausalLM.from_pretrained(model_name),
|
||||
tokenizer=AutoTokenizer.from_pretrained(model_name),
|
||||
tokenizer=AutoTokenizer.from_pretrained(token_model_name),
|
||||
)
|
||||
|
||||
|
||||
@@ -189,11 +182,14 @@ def save_json(data, filename):
|
||||
|
||||
|
||||
def collect_huggingface_logits(
|
||||
model_name: str, max_seq_len: int, to_save_json: bool
|
||||
model_name: str,
|
||||
token_model_name: str,
|
||||
max_seq_len: int,
|
||||
to_save_json: bool,
|
||||
) -> Tuple[float, float]:
|
||||
# Load
|
||||
t0 = time.time()
|
||||
model_wrapper = load_huggingface_model(model_name)
|
||||
model_wrapper = load_huggingface_model(model_name, token_model_name)
|
||||
load_time = time.time() - t0
|
||||
print("--- Took {} seconds to load Huggingface.".format(load_time))
|
||||
load_memory_info = get_memory_info()
|
||||
@@ -237,6 +233,7 @@ def collect_huggingface_logits(
|
||||
|
||||
def collect_shark_logits(
|
||||
model_name: str,
|
||||
token_model_name: str,
|
||||
max_seq_len: int,
|
||||
recompile_shark: bool,
|
||||
to_save_json: bool,
|
||||
@@ -245,7 +242,7 @@ def collect_shark_logits(
|
||||
# Load
|
||||
t0 = time.time()
|
||||
model_wrapper = load_shark_model(
|
||||
model_name, max_seq_len, recompile_shark, plugin_path
|
||||
model_name, token_model_name, max_seq_len, recompile_shark, plugin_path
|
||||
)
|
||||
load_time = time.time() - t0
|
||||
print("--- Took {} seconds to load Shark.".format(load_time))
|
||||
@@ -327,6 +324,11 @@ def parse_args():
|
||||
"facebook/opt-350m",
|
||||
"facebook/opt-1.3b",
|
||||
"facebook/opt-6.7b",
|
||||
"mit-han-lab/opt-125m-smoothquant",
|
||||
"mit-han-lab/opt-1.3b-smoothquant",
|
||||
"mit-han-lab/opt-2.7b-smoothquant",
|
||||
"mit-han-lab/opt-6.7b-smoothquant",
|
||||
"mit-han-lab/opt-13b-smoothquant",
|
||||
],
|
||||
default="facebook/opt-1.3b",
|
||||
)
|
||||
@@ -344,11 +346,17 @@ def parse_args():
|
||||
default=PLATFORM_SHARK,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plugin_path",
|
||||
"--plugin-path",
|
||||
help="path to executable plugin",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token-model-name",
|
||||
help="HF ID to create tokenizer.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
print("args={}".format(args))
|
||||
return args
|
||||
@@ -356,9 +364,17 @@ def parse_args():
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
if args.token_model_name == None:
|
||||
if "smoothquant" in args.model_name:
|
||||
args.token_model_name = (
|
||||
f"facebook/opt-{args.model_name.split('-')[3]}"
|
||||
)
|
||||
else:
|
||||
args.token_model_name = args.model_name
|
||||
if args.platform == PLATFORM_SHARK:
|
||||
shark_report = collect_shark_logits(
|
||||
args.model_name,
|
||||
args.token_model_name,
|
||||
args.max_seq_len,
|
||||
args.recompile_shark,
|
||||
args.save_json,
|
||||
@@ -367,6 +383,9 @@ if __name__ == "__main__":
|
||||
print("# Summary: {}".format(json.dumps(shark_report)))
|
||||
else:
|
||||
huggingface_report = collect_huggingface_logits(
|
||||
args.model_name, args.max_seq_len, args.save_json
|
||||
args.model_name,
|
||||
args.token_model_name,
|
||||
args.max_seq_len,
|
||||
args.save_json,
|
||||
)
|
||||
print("# Summary: {}".format(json.dumps(huggingface_report)))
|
||||
|
||||
12
tank/examples/opt/opt_util.py
Normal file
12
tank/examples/opt/opt_util.py
Normal file
@@ -0,0 +1,12 @@
|
||||
PROMPTS = [
|
||||
"What is the meaning of life?",
|
||||
"Tell me something you don't know.",
|
||||
"What does Xilinx do?",
|
||||
"What is the mass of earth?",
|
||||
"What is a poem?",
|
||||
"What is recursion?",
|
||||
"Tell me a one line joke.",
|
||||
"Who is Gilgamesh?",
|
||||
"Tell me something about cryptocurrency.",
|
||||
"How did it all begin?",
|
||||
]
|
||||
@@ -1,4 +1,3 @@
|
||||
from shark.iree_utils._common import check_device_drivers, device_driver_info
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_model
|
||||
from tank.test_utils import get_valid_test_params, shark_test_name_func
|
||||
|
||||
@@ -44,7 +44,7 @@ class TapasBaseModuleTest(unittest.TestCase):
|
||||
self.module_tester.create_and_check_module(dynamic, device)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
check_device_drivers("cuda"), reason=device_driver_info("gpu")
|
||||
check_device_drivers("cuda"), reason=device_driver_info("cuda")
|
||||
)
|
||||
def test_module_static_cuda(self):
|
||||
dynamic = False
|
||||
|
||||
Reference in New Issue
Block a user