mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00

Merge branch 'main' into ean-turbine-gen

8  .github/workflows/test-models.yml (vendored)
@@ -112,7 +112,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --benchmark=native --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
python build_tools/vicuna_testing.py

@@ -123,7 +123,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --benchmark=native --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug

@@ -146,8 +146,8 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
pytest --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail

- name: Validate Vulkan Models (Windows)
  if: matrix.suite == 'vulkan' && matrix.os == '7950x'
@@ -25,7 +25,7 @@ from apps.stable_diffusion.src import args

# Brevitas
from typing import List, Tuple
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

@@ -101,7 +101,7 @@ class H2OGPTModel(torch.nn.Module):
    dtype=torch.float32,
    weight_bit_width=weight_bit_width,
    weight_param_method="stats",
    weight_scale_precision="float",
    weight_scale_precision="float_scale",
    weight_quant_type="asym",
    weight_quant_granularity="per_group",
    weight_group_size=128,
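The same two-part substitution repeats throughout this commit: quantize_model now comes from brevitas_examples.common.generative.quantize, and weight_scale_precision changes from "float" to "float_scale". A minimal sketch collecting the new call shape, assuming the in-place quantization semantics the call sites above rely on:

# Hedged sketch: consolidates the recurring change in this commit.
import torch
from brevitas_examples.common.generative.quantize import quantize_model

def quantize_weights_int4(model, group_size=128):
    # "float_scale" is the spelling the new entry point expects; the old
    # brevitas_examples.llm.llm_quant path accepted "float" instead.
    quantize_model(
        model,
        dtype=torch.float32,
        weight_bit_width=4,
        weight_param_method="stats",
        weight_scale_precision="float_scale",
        weight_quant_type="asym",
        weight_quant_granularity="per_group",
        weight_group_size=group_size,
        quantize_weight_zero_point=False,
    )
    return model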
@@ -65,8 +65,8 @@ tiktoken==0.4.0
openai==0.27.8

# optional for chat with PDF
langchain==0.0.202
pypdf==3.12.2
langchain==0.0.329
pypdf==3.17.0
# avoid textract, requires old six
#textract==1.6.5
@@ -1,8 +1,10 @@
import argparse
from dataclasses import dataclass
import json
import re
import gc
from io import BytesIO
from os import environ
from pathlib import Path
from statistics import mean, stdev
from tqdm import tqdm

@@ -143,6 +145,12 @@ parser.add_argument(
    default=[],
    help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments."
)
parser.add_argument(
    "--enable_tracing",
    default=False,
    action=argparse.BooleanOptionalAction,
    help="Enable profiling with Tracy. The script will wait for Tracy to connect and flush the profiling data after each token."
)
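For reference, argparse.BooleanOptionalAction (Python 3.9+) generates a paired negative flag automatically, so the new option parses both spellings; a quick standalone check:

import argparse

p = argparse.ArgumentParser()
p.add_argument("--enable_tracing", default=False, action=argparse.BooleanOptionalAction)
print(p.parse_args(["--enable_tracing"]).enable_tracing)     # True
print(p.parse_args(["--no-enable_tracing"]).enable_tracing)  # False
print(p.parse_args([]).enable_tracing)                       # False (the default)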

# Microbenchmarking options.
parser.add_argument(
@@ -868,7 +876,7 @@ class ShardedVicuna(VicunaBase):
        layer0, inputs0[0], inputs0[1], inputs0[2]
    )
    if self.precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import quantize_model
        from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
    module0 = torch_mlir.compile(
        ts_g,

@@ -1069,7 +1077,7 @@ class ShardedVicuna(VicunaBase):
    )

    if self.precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import quantize_model
        from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
        print("Applying weight quantization..")
        weight_bit_width = 4 if self.precision == "int4" else 8

@@ -1079,7 +1087,7 @@ class ShardedVicuna(VicunaBase):
        weight_quant_type="asym",
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_granularity="per_group",
        weight_group_size=self.weight_group_size,
        quantize_weight_zero_point=False,
@@ -1259,6 +1267,7 @@ class UnshardedVicuna(VicunaBase):
        max_num_tokens=512,
        min_num_tokens=0,
        device="cpu",
        device_id=None,
        vulkan_target_triple="",
        precision="int8",
        vicuna_mlir_path=None,

@@ -1269,7 +1278,6 @@ class UnshardedVicuna(VicunaBase):
        download_vmfb=False,
        cache_vicunas=False,
        extra_args_cmd=[],
        device_id=None,
        debug=False,
    ) -> None:
        super().__init__(

@@ -1288,9 +1296,7 @@ class UnshardedVicuna(VicunaBase):
        print(f"[DEBUG] hf model name: {self.hf_model_path}")
        self.max_sequence_length = 256
        self.min_num_tokens = min_num_tokens
        self.device = device
        self.vulkan_target_triple = vulkan_target_triple
        self.device_id = device_id
        self.precision = precision
        self.download_vmfb = download_vmfb
        self.vicuna_vmfb_path = vicuna_vmfb_path
@@ -1299,12 +1305,24 @@ class UnshardedVicuna(VicunaBase):
        self.low_device_memory = low_device_memory
        self.weight_group_size = weight_group_size
        self.debug = debug
        # Sanity check for device, device_id pair
        if "://" in device:
            if device_id is not None:
                print("[ERR] can't have both full device path and a device id.\n"
                      f"Device : {device} | device_id : {device_id}\n"
                      "proceeding with given Device ignoring device_id")
            self.device, self.device_id = device.split("://")
            if len(self.device_id) < 2:
                self.device_id = int(self.device_id)
        else:
            self.device, self.device_id = device, device_id
        if self.vicuna_mlir_path == None:
            self.vicuna_mlir_path = self.get_model_path()
        if self.vicuna_vmfb_path == None:
            self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
        self.tokenizer = self.get_tokenizer()
        self.cache_vicunas = cache_vicunas

        self.compile()
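The rule the sanity check implements can be stated in isolation; a minimal sketch (treating multi-character ids as strings mirrors the len(...) < 2 test above):

def split_device(device, device_id=None):
    # "vulkan://1" -> ("vulkan", 1); multi-character ids stay strings,
    # e.g. "rocm://GPU-deadbeef" -> ("rocm", "GPU-deadbeef").
    if "://" in device:
        name, raw_id = device.split("://")
        return name, int(raw_id) if len(raw_id) < 2 else raw_id
    return device, device_id

print(split_device("vulkan://1"))  # ('vulkan', 1)
print(split_device("cpu"))         # ('cpu', None)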
    def get_model_path(self, suffix="mlir"):

@@ -1313,13 +1331,27 @@ class UnshardedVicuna(VicunaBase):
        if suffix in ["mlirbc", "mlir"]:
            return Path(f"{self.model_name}_{self.precision}.{suffix}")

        target_triple = ""
        if self.vulkan_target_triple != "":
            target_triple = "_"
            target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1])

        # Need to distinguish between multiple vmfbs of the same model
        # compiled for different devices of the same driver.
        # Driver - Differentiator
        # Vulkan - target_triple
        # ROCm   - device_arch
        differentiator = ""
        if "vulkan" == self.device:
            target_triple = ""
            if self.vulkan_target_triple != "":
                target_triple = "_"
                target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1])
            differentiator = target_triple

        elif "rocm" == self.device:
            from shark.iree_utils.gpu_utils import get_rocm_device_arch
            device_arch = get_rocm_device_arch(self.device_id if self.device_id is not None else 0, self.extra_args)
            differentiator = '_' + device_arch

        return Path(
            f"{self.model_name}_{self.precision}_{safe_device}{target_triple}.{suffix}"
            f"{self.model_name}_{self.precision}_{safe_device}{differentiator}.{suffix}"
        )
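To make the differentiator concrete, here are hypothetical artifact names this logic would produce (the target triple and ROCm arch values are illustrative, and safe_device is assumed to be the sanitized driver name used elsewhere in this file):

# vulkan, target triple "rdna2-unknown-windows":
#   vicuna_int4_vulkan_rdna2_unknown.vmfb
# rocm, get_rocm_device_arch(...) == "gfx90a":
#   vicuna_int4_rocm_gfx90a.vmfb
# cpu (empty differentiator):
#   vicuna_int4_cpu.vmfb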
def get_tokenizer(self):
@@ -1752,9 +1784,8 @@ class UnshardedVicuna(VicunaBase):
        )
        del first_module, second_module

        print(self.device)
        if "rocm" in self.device:
            self.device = "rocm"
        print(f"Compiling for device : {self.device}"
              f"{'://' + str(self.device_id) if self.device_id is not None else ''}")
        shark_module = SharkInference(
            mlir_module=combined_module,
            device=self.device,
@@ -1914,12 +1945,107 @@ def create_prompt(model_name, history):
    return msg


def miliseconds_to_seconds(ms: float) -> float:
    return ms / 1000.0


@dataclass
class BenchmarkRunInfo:
    num_prompt_tokens: int
    prefill_time_ms: float
    token_times_ms: list[float]

    def get_prefill_speed(self) -> float:
        seconds = miliseconds_to_seconds(self.prefill_time_ms)
        if seconds == 0.0:
            return float('inf')
        return self.num_prompt_tokens / seconds

    def num_generated_tokens(self) -> int:
        return len(self.token_times_ms)

    def get_decode_time_ms(self) -> float:
        return sum(self.token_times_ms)

    def get_decode_speed(self) -> float:
        seconds = miliseconds_to_seconds(self.get_decode_time_ms())
        if seconds == 0.0:
            return float('inf')
        return self.num_generated_tokens() / seconds

    def get_e2e_time_ms(self) -> float:
        return self.prefill_time_ms + self.get_decode_time_ms()

    def get_e2e_decode_speed(self) -> float:
        seconds = miliseconds_to_seconds(self.get_e2e_time_ms())
        if seconds == 0.0:
            return float('inf')
        return self.num_generated_tokens() / seconds

    def get_e2e_token_processing_speed(self) -> float:
        seconds = miliseconds_to_seconds(self.get_e2e_time_ms())
        if seconds == 0.0:
            return float('inf')
        return (self.num_prompt_tokens + self.num_generated_tokens()) / seconds

    def print(self) -> None:
        total_tokens = self.num_prompt_tokens + self.num_generated_tokens()
        print(f"Num tokens: {self.num_prompt_tokens} (prompt), {self.num_generated_tokens()} (generated), {total_tokens} (total)")
        print(f"Prefill: {self.prefill_time_ms:.2f} ms, {self.get_prefill_speed():.2f} tokens/s")
        print(f"Decode: {self.get_decode_time_ms():.2f} ms, {self.get_decode_speed():.2f} tokens/s")
        print(f"Decode end-2-end: {self.get_e2e_decode_speed():.2f} tokens/s (w/o prompt), {self.get_e2e_token_processing_speed():.2f} tokens/s (w/ prompt)")
def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None:
    num_iterations = len(run_infos)
    print(f'Number of iterations: {num_iterations}')
    if num_iterations == 0:
        return

    if len(run_infos) == 1:
        run_infos[0].print()
        return

    total_tokens = run_infos[0].num_prompt_tokens + run_infos[0].num_generated_tokens()
    print(f"Num tokens: {run_infos[0].num_prompt_tokens} (prompt), {run_infos[0].num_generated_tokens()} (generated), {total_tokens} (total)")

    def avg_and_stdev(data):
        x = list(data)
        return mean(x), stdev(x)

    avg_prefill_ms, stdev_prefill = avg_and_stdev(x.prefill_time_ms for x in run_infos)
    avg_prefill_speed = mean(x.get_prefill_speed() for x in run_infos)
    print(f"Prefill: avg. {avg_prefill_ms:.2f} ms (stdev {stdev_prefill:.2f}), avg. {avg_prefill_speed:.2f} tokens/s")

    avg_decode_ms, stdev_decode = avg_and_stdev(x.get_decode_time_ms() for x in run_infos)
    avg_decode_speed = mean(x.get_decode_speed() for x in run_infos)
    print(f"Decode: avg. {avg_decode_ms:.2f} ms (stdev {stdev_decode:.2f}), avg. {avg_decode_speed:.2f} tokens/s")

    avg_e2e_decode_speed = mean(x.get_e2e_decode_speed() for x in run_infos)
    avg_e2e_processing_speed = mean(x.get_e2e_token_processing_speed() for x in run_infos)
    print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} (w/ prompt)")
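A short, self-contained example of how these helpers fit together (all timings invented for illustration):

# Two fake runs: 10-token prompt, ~100 ms prefill, three generated tokens each.
runs = [
    BenchmarkRunInfo(num_prompt_tokens=10, prefill_time_ms=100.0, token_times_ms=[50.0, 55.0, 45.0]),
    BenchmarkRunInfo(num_prompt_tokens=10, prefill_time_ms=120.0, token_times_ms=[52.0, 48.0, 50.0]),
]
runs[0].print()              # per-run numbers, e.g. prefill 100.00 ms -> 100.00 tokens/s
print_aggregate_stats(runs)  # means and stdevs across the two runs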
def enable_tracy_tracing():
    # Make Tracy wait for a capture to be collected before exiting.
    environ["TRACY_NO_EXIT"] = "1"

    if "IREE_PY_RUNTIME" not in environ or environ["IREE_PY_RUNTIME"] != "tracy":
        print("ERROR: Tracing enabled but the tracy IREE runtime is not in use.", file=sys.stderr)
        print("Set the IREE_PY_RUNTIME=tracy environment variable.", file=sys.stderr)
        sys.exit(1)
if __name__ == "__main__":
    args, unknown = parser.parse_known_args()

    _extra_args = list(args.Xiree_compile)

    device_id = None

    if args.enable_tracing:
        enable_tracy_tracing()

    # Process vulkan target triple.
    # TODO: This feature should just be in a common utils for other LLMs and in general
    # any model run via SHARK for Vulkan backend.

@@ -2012,8 +2138,7 @@ if __name__ == "__main__":

    iteration = 0

    prefill_times = []
    avg_decode_speed = []
    benchmark_run_infos = []

    while True:
        # TODO: Add break condition from user input
@@ -2029,28 +2154,30 @@ if __name__ == "__main__":
        prompt = args.system_prompt + user_prompt
        history = [[user_prompt, ""]]

        token_count = 0
        total_time_ms = 0.001  # In order to avoid divide by zero error
        prefill_time = 0
        prompt_token_count = len(vic.tokenizer(prompt).input_ids)
        total_time_ms = 0.0  # In order to avoid divide by zero error
        prefill_time_ms = 0
        is_first = True
        token_times_ms = []

        for text, msg, exec_time in vic.generate(prompt, cli=True):
            if args.enable_tracing:
                vic.shark_model.shark_runner.iree_config.device.flush_profiling()

            if msg is None:
                if is_first:
                    prefill_time = exec_time
                    # Note that the prefill time is in seconds, and all the decoded tokens in ms.
                    prefill_time_ms = exec_time * 1000
                    is_first = False
                else:
                    total_time_ms += exec_time
                    token_count += 1
                    token_times_ms.append(exec_time)
            elif "formatted" in msg:
                history[-1][1] = text
                tokens_per_sec = (token_count / total_time_ms) * 1000
                prefill_times.append(prefill_time)
                avg_decode_speed.append(tokens_per_sec)
                print(f"\nResponse:\n{text.strip()}\n")
                run_info = BenchmarkRunInfo(prompt_token_count, prefill_time_ms, token_times_ms)
                run_info.print()
                benchmark_run_infos.append(run_info)

                print("\nResponse:", text.strip())
                print(f"\nNum tokens: {token_count}")
                print(f"Prefill: {prefill_time:.2f} seconds")
                print(f"Decode: {tokens_per_sec:.2f} tokens/s")
            else:
                sys.exit(
                    "unexpected message from the vicuna generate call, exiting."
@@ -2058,6 +2185,4 @@ if __name__ == "__main__":

    if args.enable_microbenchmark:
        print("\n### Final Statistics ###")
        print("Number of iterations:", iteration - 1)
        print(f"Prefill: avg. {mean(prefill_times):.2f} s, stdev {stdev(prefill_times):.2f}")
        print(f"Decode: avg. {mean(avg_decode_speed):.2f} tokens/s, stdev {stdev(avg_decode_speed):.2f}")
        print_aggregate_stats(benchmark_run_infos)
@@ -69,91 +69,7 @@ class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
        return torch.tensor(new_hidden_states)


class DecoderLayer(torch.nn.Module):
    def __init__(self, decoder_layer_model, falcon_variant):
        super().__init__()
        self.model = decoder_layer_model

    def forward(self, hidden_states, attention_mask):
        output = self.model.forward(
            hidden_states=hidden_states,
            alibi=None,
            attention_mask=attention_mask,
            use_cache=True,
        )
        return (output[0], output[1][0], output[1][1])


class CompiledDecoderLayer(torch.nn.Module):
    def __init__(
        self, layer_id, device_idx, falcon_variant, device, precision
    ):
        super().__init__()
        self.layer_id = layer_id
        self.device_index = device_idx
        self.falcon_variant = falcon_variant
        self.device = device
        self.precision = precision

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        alibi: torch.Tensor = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        import gc

        torch.cuda.empty_cache()
        gc.collect()
        from pathlib import Path
        from apps.language_models.utils import get_vmfb_from_path

        self.falcon_vmfb_path = Path(
            f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
        )
        print("vmfb path for layer: ", self.falcon_vmfb_path)
        self.model = get_vmfb_from_path(
            self.falcon_vmfb_path,
            self.device,
            "linalg",
            device_id=self.device_index,
        )
        if self.model is None:
            raise ValueError("Layer vmfb not found")

        hidden_states = hidden_states.to(torch.float32).detach().numpy()
        attention_mask = attention_mask.detach().numpy()

        if alibi is not None or layer_past is not None:
            raise ValueError("Past Key Values and alibi should be None")
        else:
            new_hidden_states, pkv1, pkv2 = self.model(
                "forward",
                (
                    hidden_states,
                    attention_mask,
                ),
            )
            del self.model

        return tuple(
            [
                torch.tensor(new_hidden_states),
                tuple(
                    [
                        torch.tensor(pkv1),
                        torch.tensor(pkv2),
                    ]
                ),
            ]
        )


class EightDecoderLayer(torch.nn.Module):
class FourWayShardingDecoderLayer(torch.nn.Module):
    def __init__(self, decoder_layer_model, falcon_variant):
        super().__init__()
        self.model = decoder_layer_model
@@ -175,163 +91,78 @@ class EightDecoderLayer(torch.nn.Module):
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
if self.falcon_variant == "7b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
elif self.falcon_variant == "40b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
)
|
||||
elif self.falcon_variant == "180b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported Falcon variant: ", self.falcon_variant
|
||||
)
|
||||
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledEightDecoderLayer(torch.nn.Module):
class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
    def __init__(
        self, layer_id, device_idx, falcon_variant, device, precision
        self, layer_id, device_idx, falcon_variant, device, precision, model
    ):
        super().__init__()
        self.layer_id = layer_id

@@ -339,6 +170,7 @@ class CompiledEightDecoderLayer(torch.nn.Module):
        self.falcon_variant = falcon_variant
        self.device = device
        self.precision = precision
        self.model = model

    def forward(
        self,

@@ -354,19 +186,7 @@ class CompiledEightDecoderLayer(torch.nn.Module):

        torch.cuda.empty_cache()
        gc.collect()
        from pathlib import Path
        from apps.language_models.utils import get_vmfb_from_path

        self.falcon_vmfb_path = Path(
            f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
        )
        print("vmfb path for layer: ", self.falcon_vmfb_path)
        self.model = get_vmfb_from_path(
            self.falcon_vmfb_path,
            self.device,
            "linalg",
            device_id=self.device_index,
        )
        if self.model is None:
            raise ValueError("Layer vmfb not found")
@@ -383,196 +203,451 @@ class CompiledEightDecoderLayer(torch.nn.Module):
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
del self.model
|
||||
|
||||
if self.falcon_variant == "7b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class TwoWayShardingDecoderLayer(torch.nn.Module):
    def __init__(self, decoder_layer_model, falcon_variant):
        super().__init__()
        self.model = decoder_layer_model
        self.falcon_variant = falcon_variant

    def forward(self, hidden_states, attention_mask):
        new_pkvs = []
        for layer in self.model:
            outputs = layer(
                hidden_states=hidden_states,
                alibi=None,
                attention_mask=attention_mask,
                use_cache=True,
            )
elif self.falcon_variant == "40b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
)
|
||||
elif self.falcon_variant == "180b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
(new_pkv200, new_pkv201),
|
||||
(new_pkv210, new_pkv211),
|
||||
(new_pkv220, new_pkv221),
|
||||
(new_pkv230, new_pkv231),
|
||||
(new_pkv240, new_pkv241),
|
||||
(new_pkv250, new_pkv251),
|
||||
(new_pkv260, new_pkv261),
|
||||
(new_pkv270, new_pkv271),
|
||||
(new_pkv280, new_pkv281),
|
||||
(new_pkv290, new_pkv291),
|
||||
(new_pkv300, new_pkv301),
|
||||
(new_pkv310, new_pkv311),
|
||||
(new_pkv320, new_pkv321),
|
||||
(new_pkv330, new_pkv331),
|
||||
(new_pkv340, new_pkv341),
|
||||
(new_pkv350, new_pkv351),
|
||||
(new_pkv360, new_pkv361),
|
||||
(new_pkv370, new_pkv371),
|
||||
(new_pkv380, new_pkv381),
|
||||
(new_pkv390, new_pkv391),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
new_pkv200,
|
||||
new_pkv201,
|
||||
new_pkv210,
|
||||
new_pkv211,
|
||||
new_pkv220,
|
||||
new_pkv221,
|
||||
new_pkv230,
|
||||
new_pkv231,
|
||||
new_pkv240,
|
||||
new_pkv241,
|
||||
new_pkv250,
|
||||
new_pkv251,
|
||||
new_pkv260,
|
||||
new_pkv261,
|
||||
new_pkv270,
|
||||
new_pkv271,
|
||||
new_pkv280,
|
||||
new_pkv281,
|
||||
new_pkv290,
|
||||
new_pkv291,
|
||||
new_pkv300,
|
||||
new_pkv301,
|
||||
new_pkv310,
|
||||
new_pkv311,
|
||||
new_pkv320,
|
||||
new_pkv321,
|
||||
new_pkv330,
|
||||
new_pkv331,
|
||||
new_pkv340,
|
||||
new_pkv341,
|
||||
new_pkv350,
|
||||
new_pkv351,
|
||||
new_pkv360,
|
||||
new_pkv361,
|
||||
new_pkv370,
|
||||
new_pkv371,
|
||||
new_pkv380,
|
||||
new_pkv381,
|
||||
new_pkv390,
|
||||
new_pkv391,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
    def __init__(
        self, layer_id, device_idx, falcon_variant, device, precision, model
    ):
        super().__init__()
        self.layer_id = layer_id
        self.device_index = device_idx
        self.falcon_variant = falcon_variant
        self.device = device
        self.precision = precision
        self.model = model

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        alibi: torch.Tensor = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        import gc

        torch.cuda.empty_cache()
        gc.collect()

        if self.model is None:
            raise ValueError("Layer vmfb not found")

        hidden_states = hidden_states.to(torch.float32).detach().numpy()
        attention_mask = attention_mask.detach().numpy()

        if alibi is not None or layer_past is not None:
            raise ValueError("Past Key Values and alibi should be None")
else:
|
||||
raise ValueError(
|
||||
"Unsupported Falcon variant: ", self.falcon_variant
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[41]),
|
||||
torch.tensor(output[42]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[43]),
|
||||
torch.tensor(output[44]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[45]),
|
||||
torch.tensor(output[46]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[47]),
|
||||
torch.tensor(output[48]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[49]),
|
||||
torch.tensor(output[50]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[51]),
|
||||
torch.tensor(output[52]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[53]),
|
||||
torch.tensor(output[54]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[55]),
|
||||
torch.tensor(output[56]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[57]),
|
||||
torch.tensor(output[58]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[59]),
|
||||
torch.tensor(output[60]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[61]),
|
||||
torch.tensor(output[62]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[63]),
|
||||
torch.tensor(output[64]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[65]),
|
||||
torch.tensor(output[66]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[67]),
|
||||
torch.tensor(output[68]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[69]),
|
||||
torch.tensor(output[70]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[71]),
|
||||
torch.tensor(output[72]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[73]),
|
||||
torch.tensor(output[74]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[75]),
|
||||
torch.tensor(output[76]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[77]),
|
||||
torch.tensor(output[78]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[79]),
|
||||
torch.tensor(output[80]),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Any
from transformers import StoppingCriteria

from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

@@ -37,7 +37,7 @@ class VisionModel(torch.nn.Module):
    dtype=torch.float32,
    weight_bit_width=weight_bit_width,
    weight_param_method="stats",
    weight_scale_precision="float",
    weight_scale_precision="float_scale",
    weight_quant_type="asym",
    weight_quant_granularity="per_group",
    weight_group_size=weight_group_size,

@@ -52,7 +52,7 @@ class VisionModel(torch.nn.Module):
    dtype=torch.float32,
    weight_bit_width=weight_bit_width,
    weight_param_method="stats",
    weight_scale_precision="float",
    weight_scale_precision="float_scale",
    weight_quant_type="asym",
    weight_quant_granularity="per_group",
    weight_group_size=weight_group_size,

@@ -93,7 +93,7 @@ class FirstLlamaModel(torch.nn.Module):
    dtype=torch.float32,
    weight_bit_width=weight_bit_width,
    weight_param_method="stats",
    weight_scale_precision="float",
    weight_scale_precision="float_scale",
    weight_quant_type="asym",
    weight_quant_granularity="per_group",
    weight_group_size=weight_group_size,

@@ -157,7 +157,7 @@ class SecondLlamaModel(torch.nn.Module):
    dtype=torch.float32,
    weight_bit_width=weight_bit_width,
    weight_param_method="stats",
    weight_scale_precision="float",
    weight_scale_precision="float_scale",
    weight_quant_type="asym",
    weight_quant_granularity="per_group",
    weight_group_size=weight_group_size,
@@ -24,7 +24,9 @@ class FirstVicuna(torch.nn.Module):
    )
    print(f"[DEBUG] model_path : {model_path}")
    if precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import (
            quantize_model,
        )
        from brevitas_examples.llm.llm_quant.run_utils import (
            get_model_impl,
        )

@@ -36,7 +38,7 @@ class FirstVicuna(torch.nn.Module):
        dtype=self.accumulates,
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_type="asym",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,

@@ -79,7 +81,9 @@ class SecondVicuna7B(torch.nn.Module):
    )
    print(f"[DEBUG] model_path : {model_path}")
    if precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import (
            quantize_model,
        )
        from brevitas_examples.llm.llm_quant.run_utils import (
            get_model_impl,
        )

@@ -91,7 +95,7 @@ class SecondVicuna7B(torch.nn.Module):
        dtype=self.accumulates,
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_type="asym",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,

@@ -329,7 +333,9 @@ class SecondVicuna13B(torch.nn.Module):
        torch.float32 if accumulates == "fp32" else torch.float16
    )
    if precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import (
            quantize_model,
        )
        from brevitas_examples.llm.llm_quant.run_utils import (
            get_model_impl,
        )

@@ -341,7 +347,7 @@ class SecondVicuna13B(torch.nn.Module):
        dtype=self.accumulates,
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_type="asym",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,

@@ -627,7 +633,9 @@ class SecondVicuna70B(torch.nn.Module):
    )
    print(f"[DEBUG] model_path : {model_path}")
    if precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import (
            quantize_model,
        )
        from brevitas_examples.llm.llm_quant.run_utils import (
            get_model_impl,
        )

@@ -639,7 +647,7 @@ class SecondVicuna70B(torch.nn.Module):
        dtype=self.accumulates,
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_type="asym",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,
@@ -24,7 +24,9 @@ class FirstVicunaGPU(torch.nn.Module):
    )
    print(f"[DEBUG] model_path : {model_path}")
    if precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import (
            quantize_model,
        )
        from brevitas_examples.llm.llm_quant.run_utils import (
            get_model_impl,
        )

@@ -36,7 +38,7 @@ class FirstVicunaGPU(torch.nn.Module):
        dtype=self.accumulates,
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_type="asym",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,

@@ -78,7 +80,9 @@ class SecondVicuna7BGPU(torch.nn.Module):
    )
    print(f"[DEBUG] model_path : {model_path}")
    if precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import (
            quantize_model,
        )
        from brevitas_examples.llm.llm_quant.run_utils import (
            get_model_impl,
        )

@@ -90,7 +94,7 @@ class SecondVicuna7BGPU(torch.nn.Module):
        dtype=self.accumulates,
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_type="asym",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,

@@ -327,7 +331,9 @@ class SecondVicuna13BGPU(torch.nn.Module):
        torch.float32 if accumulates == "fp32" else torch.float16
    )
    if precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import (
            quantize_model,
        )
        from brevitas_examples.llm.llm_quant.run_utils import (
            get_model_impl,
        )

@@ -339,7 +345,7 @@ class SecondVicuna13BGPU(torch.nn.Module):
        dtype=self.accumulates,
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_type="asym",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,

@@ -625,7 +631,9 @@ class SecondVicuna70BGPU(torch.nn.Module):
    )
    print(f"[DEBUG] model_path : {model_path}")
    if precision in ["int4", "int8"]:
        from brevitas_examples.llm.llm_quant.quantize import quantize_model
        from brevitas_examples.common.generative.quantize import (
            quantize_model,
        )
        from brevitas_examples.llm.llm_quant.run_utils import (
            get_model_impl,
        )

@@ -637,7 +645,7 @@ class SecondVicuna70BGPU(torch.nn.Module):
        dtype=self.accumulates,
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_type="asym",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,
@@ -6,10 +6,10 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
    CompiledLNFEmbeddingLayer,
    LMHeadEmbeddingLayer,
    CompiledLMHeadEmbeddingLayer,
    DecoderLayer,
    EightDecoderLayer,
    CompiledDecoderLayer,
    CompiledEightDecoderLayer,
    FourWayShardingDecoderLayer,
    TwoWayShardingDecoderLayer,
    CompiledFourWayShardingDecoderLayer,
    CompiledTwoWayShardingDecoderLayer,
    ShardedFalconModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase

@@ -94,6 +94,13 @@ parser.add_argument(
    action=argparse.BooleanOptionalAction,
    help="Run model as sharded",
)
parser.add_argument(
    "--num_shards",
    type=int,
    default=4,
    choices=[2, 4],
    help="Number of shards.",
)


class ShardedFalcon(SharkLLMBase):

@@ -122,6 +129,10 @@ class ShardedFalcon(SharkLLMBase):
        --hf_auth_token flag. You can ask for the access to the model
        here: https://huggingface.co/tiiuae/falcon-180B-chat."""
            )

        if args.sharded and "180b" not in self.model_name:
            raise ValueError("Sharding supported only for Falcon-180B")

        self.hf_auth_token = hf_auth_token
        self.max_padding_length = 100
        self.device = device

@@ -131,7 +142,7 @@ class ShardedFalcon(SharkLLMBase):
        self.debug = debug
        self.tokenizer = self.get_tokenizer()
        self.src_model = self.get_src_model()
        self.shark_model = self.compile(compressed=args.compressed)
        self.shark_model = self.compile()

    def get_tokenizer(self):
        tokenizer = AutoTokenizer.from_pretrained(

@@ -146,20 +157,17 @@ class ShardedFalcon(SharkLLMBase):
    def get_src_model(self):
        print("Loading src model: ", self.model_name)
        kwargs = {
            "torch_dtype": torch.float,
            "torch_dtype": torch.float32,
            "trust_remote_code": True,
            "token": self.hf_auth_token,
        }
        if self.precision == "int4":
            quantization_config = GPTQConfig(bits=4, disable_exllama=True)
            kwargs["quantization_config"] = quantization_config
            kwargs["load_gptq_on_cpu"] = True
            kwargs["device_map"] = "cpu"
        falcon_model = AutoModelForCausalLM.from_pretrained(
            self.hf_model_path, **kwargs
        )
        if self.precision == "int4":
            falcon_model = falcon_model.to(torch.float32)
        return falcon_model
    def compile_layer(

@@ -288,28 +296,14 @@ class ShardedFalcon(SharkLLMBase):

        return shark_module, device_idx

    def compile(self, compressed=False):
    def compile(self):
        sample_input_ids = torch.zeros([100], dtype=torch.int64)
        sample_attention_mask = torch.zeros(
            [1, 1, 100, 100], dtype=torch.float32
        )
        num_group_layers = 1
        if "7b" in self.model_name:
            num_in_features = 4544
            if compressed:
                num_group_layers = 8
        elif "40b" in self.model_name:
            num_in_features = 8192
            if compressed:
                num_group_layers = 15
        else:
            num_in_features = 14848
            sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
            if compressed:
                num_group_layers = 20

        sample_attention_mask = torch.zeros([1, 1, 100, 100], dtype=torch.bool)
        num_group_layers = int(
            20 * (4 / args.num_shards)
        )  # 4 is the number of default shards
        sample_hidden_states = torch.zeros(
            [1, 100, num_in_features], dtype=torch.float32
            [1, 100, 14848], dtype=torch.float32
        )

        # Determine number of available devices
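A worked check of the new grouping arithmetic, assuming the 80 decoder layers of Falcon-180B (not code from the commit):

for num_shards in (2, 4):
    num_group_layers = int(20 * (4 / num_shards))
    print(num_shards, num_group_layers, 80 // num_group_layers)
# 4 shards -> groups of 20 layers -> 4 compiled decoder groups, one per shard
# 2 shards -> groups of 40 layers -> 2 compiled decoder groups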
@@ -319,6 +313,10 @@ class ShardedFalcon(SharkLLMBase):

        haldriver = ireert.get_driver(self.device)
        num_devices = len(haldriver.query_available_devices())
        if num_devices < 2:
            raise ValueError(
                "Cannot run Falcon-180B on a single ROCM device."
            )

        lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
        print("Compiling Layer lm_head")

@@ -326,7 +324,9 @@ class ShardedFalcon(SharkLLMBase):
            lm_head,
            [sample_hidden_states],
            "lm_head",
            device_idx=0 % num_devices if self.device == "rocm" else None,
            device_idx=(0 % num_devices) % args.num_shards
            if self.device == "rocm"
            else None,
        )
        shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)

@@ -338,7 +338,9 @@ class ShardedFalcon(SharkLLMBase):
            word_embedding,
            [sample_input_ids],
            "word_embeddings",
            device_idx=1 % num_devices if self.device == "rocm" else None,
            device_idx=(1 % num_devices) % args.num_shards
            if self.device == "rocm"
            else None,
        )
        shark_word_embedding = CompiledWordEmbeddingsLayer(
            shark_word_embedding

@@ -350,7 +352,9 @@ class ShardedFalcon(SharkLLMBase):
            ln_f,
            [sample_hidden_states],
            "ln_f",
            device_idx=2 % num_devices if self.device == "rocm" else None,
            device_idx=(2 % num_devices) % args.num_shards
            if self.device == "rocm"
            else None,
        )
        shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)

@@ -360,24 +364,21 @@ class ShardedFalcon(SharkLLMBase):
        ):
            device_idx = i % num_devices if self.device == "rocm" else None
            layer_id = i
            pytorch_class = DecoderLayer
            compiled_class = CompiledDecoderLayer
            if compressed:
                layer_id = (
                    str(i * num_group_layers)
                    + "_"
                    + str((i + 1) * num_group_layers)
                )
                pytorch_class = EightDecoderLayer
                compiled_class = CompiledEightDecoderLayer
            layer_id = (
                str(i * num_group_layers)
                + "_"
                + str((i + 1) * num_group_layers)
            )
            pytorch_class = FourWayShardingDecoderLayer
            compiled_class = CompiledFourWayShardingDecoderLayer
            if args.num_shards == 2:
                pytorch_class = TwoWayShardingDecoderLayer
                compiled_class = CompiledTwoWayShardingDecoderLayer

            print("Compiling Layer {}".format(layer_id))
            if compressed:
                layer_i = self.src_model.transformer.h[
                    i * num_group_layers : (i + 1) * num_group_layers
                ]
            else:
                layer_i = self.src_model.transformer.h[i]
            layer_i = self.src_model.transformer.h[
                i * num_group_layers : (i + 1) * num_group_layers
            ]

            pytorch_layer_i = pytorch_class(
                layer_i, args.falcon_variant_to_use

@@ -388,13 +389,13 @@ class ShardedFalcon(SharkLLMBase):
                layer_id,
                device_idx=device_idx,
            )
            del shark_module
            shark_layer_i = compiled_class(
                layer_id,
                device_idx,
                args.falcon_variant_to_use,
                self.device,
                self.precision,
                shark_module,
            )
            shark_layers.append(shark_layer_i)

@@ -132,7 +132,7 @@ import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from typing import List, Tuple
from io import BytesIO
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
91  apps/shark_studio/api/llm.py (new file)
@@ -0,0 +1,91 @@
from turbine_models.custom_models import stateless_llama
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
import gc
import torch

llm_model_map = {
    "llama2_7b": {
        "initializer": stateless_llama.export_transformer_model,
        "hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
        "stop_token": 2,
        "max_tokens": 4096,
    }
}
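llm_model_map acts as a small registry, so wiring in another Turbine-exported model would just mean adding an entry; a hypothetical sketch (the llama2_13b key and HF model id are illustrative, not part of the commit):

llm_model_map["llama2_13b"] = {
    "initializer": stateless_llama.export_transformer_model,  # same Turbine exporter
    "hf_model_name": "meta-llama/Llama-2-13b-chat-hf",        # assumed HF id
    "stop_token": 2,       # Llama-2 EOS token id
    "max_tokens": 4096,
}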
|
||||
|
||||
class LanguageModel:
|
||||
def __init__(
|
||||
self, model_name, hf_auth_token=None, device=None, precision="fp32"
|
||||
):
|
||||
print(llm_model_map[model_name])
|
||||
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
|
||||
self.torch_ir, self.tokenizer = llm_model_map[model_name][
|
||||
"initializer"
|
||||
](self.hf_model_name, hf_auth_token, compile_to="torch")
|
||||
self.tempfile_name = get_resource_path("llm.torch.tempfile")
|
||||
with open(self.tempfile_name, "w+") as f:
|
||||
f.write(self.torch_ir)
|
||||
del self.torch_ir
|
||||
gc.collect()
|
||||
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.max_tokens = llm_model_map[model_name]["max_tokens"]
|
||||
self.iree_module_dict = None
|
||||
self.compile()
|
||||
|
||||
def compile(self) -> None:
|
||||
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
|
||||
self.iree_module_dict = get_iree_compiled_module(
|
||||
self.tempfile_name, device=self.device, frontend="torch"
|
||||
)
|
||||
# TODO: delete the temp file
|
||||
|
||||
def chat(self, prompt):
|
||||
history = []
|
||||
for iter in range(self.max_tokens):
|
||||
input_tensor = self.tokenizer(
|
||||
prompt, return_tensors="pt"
|
||||
).input_ids
|
||||
device_inputs = [
|
||||
ireert.asdevicearray(
|
||||
self.iree_module_dict["config"], input_tensor
|
||||
)
|
||||
]
|
||||
if iter == 0:
|
||||
token = torch.tensor(
|
||||
self.iree_module_dict["vmfb"]["run_initialize"](
|
||||
*device_inputs
|
||||
).to_host()[0][0]
|
||||
)
|
||||
else:
|
||||
token = torch.tensor(
|
||||
self.iree_module_dict["vmfb"]["run_forward"](
|
||||
*device_inputs
|
||||
).to_host()[0][0]
|
||||
)
|
||||
|
||||
history.append(token)
|
||||
yield self.tokenizer.decode(history)
|
||||
|
||||
if token == llm_model_map["llama2_7b"]["stop_token"]:
|
||||
break
|
||||
|
||||
for i in range(len(history)):
|
||||
if type(history[i]) != int:
|
||||
history[i] = int(history[i])
|
||||
result_output = self.tokenizer.decode(history)
|
||||
yield result_output
|
||||
|
||||
|
||||
if __name__ == "__main__":
lm = LanguageModel(
"llama2_7b",
hf_auth_token="<YOUR_HF_TOKEN>",  # redacted; never commit a real token
device="cpu-task",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
print(i)
14
apps/shark_studio/api/utils.py
Normal file
@@ -0,0 +1,14 @@
import os
import sys


def get_available_devices():
return ["cpu-task"]


def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
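For reference, sys._MEIPASS only exists inside a PyInstaller bundle, so the getattr above falls back to the module's own directory during development. A quick illustration of the two cases:

import os
import sys

# dev run: _MEIPASS is absent, so paths resolve next to this file
# frozen run: PyInstaller unpacks resources under sys._MEIPASS
base = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
print(os.path.join(base, "llm.torch.tempfile"))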
428
apps/shark_studio/web/index.py
Normal file
@@ -0,0 +1,428 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
from ui.chat import chat_element

if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir

# import PIL, transformers, sentencepiece  # ensures inclusion in pyinstaller exe generation
# from apps.stable_diffusion.src import args, clear_all
# import apps.stable_diffusion.web.utils.global_obj as global_obj


def launch_app(address):
from tkinter import Tk
import webview

window = Tk()

# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())


if __name__ == "__main__":
# if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.stable_diffusion.web.ui import (
# txt2img_api,
# img2img_api,
# upscaler_api,
# inpaint_api,
# outpaint_api,
# llm_chat_api,
# )
#
# from fastapi import FastAPI, APIRouter
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# app = FastAPI()
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# app.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
# app.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# app.include_router(APIRouter())
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
#
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
# from apps.stable_diffusion.web.utils.gradio_configs import (
# config_gradio_tmp_imgs_folder,
# )

# config_gradio_tmp_imgs_folder()
import gradio as gr

# Create custom models folders if they don't exist
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders

# create_custom_models_folders()

def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)

dark_theme = resource_path("ui/css/sd_dark_theme.css")

# from apps.stable_diffusion.web.ui import (
# txt2img_web,
# txt2img_custom_model,
# txt2img_gallery,
# txt2img_png_info_img,
# txt2img_status,
# txt2img_sendto_img2img,
# txt2img_sendto_inpaint,
# txt2img_sendto_outpaint,
# txt2img_sendto_upscaler,
## h2ogpt_upload,
## h2ogpt_web,
# img2img_web,
# img2img_custom_model,
# img2img_gallery,
# img2img_init_image,
# img2img_status,
# img2img_sendto_inpaint,
# img2img_sendto_outpaint,
# img2img_sendto_upscaler,
# inpaint_web,
# inpaint_custom_model,
# inpaint_gallery,
# inpaint_init_image,
# inpaint_status,
# inpaint_sendto_img2img,
# inpaint_sendto_outpaint,
# inpaint_sendto_upscaler,
# outpaint_web,
# outpaint_custom_model,
# outpaint_gallery,
# outpaint_init_image,
# outpaint_status,
# outpaint_sendto_img2img,
# outpaint_sendto_inpaint,
# outpaint_sendto_upscaler,
# upscaler_web,
# upscaler_custom_model,
# upscaler_gallery,
# upscaler_init_image,
# upscaler_status,
# upscaler_sendto_img2img,
# upscaler_sendto_inpaint,
# upscaler_sendto_outpaint,
## lora_train_web,
## model_web,
## model_config_web,
# hf_models,
# modelmanager_sendto_txt2img,
# modelmanager_sendto_img2img,
# modelmanager_sendto_inpaint,
# modelmanager_sendto_outpaint,
# modelmanager_sendto_upscaler,
# stablelm_chat,
# minigpt4_web,
# outputgallery_web,
# outputgallery_tab_select,
# outputgallery_watch,
# outputgallery_filename,
# outputgallery_sendto_txt2img,
# outputgallery_sendto_img2img,
# outputgallery_sendto_inpaint,
# outputgallery_sendto_outpaint,
# outputgallery_sendto_upscaler,
# )

# init global sd pipeline and config
# global_obj._init()

def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)

def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)

def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)

with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.

# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
# with gr.TabItem(label="Text-to-Image", id=0):
# txt2img_web.render()
# with gr.TabItem(label="Image-to-Image", id=1):
# img2img_web.render()
# with gr.TabItem(label="Inpainting", id=2):
# inpaint_web.render()
# with gr.TabItem(label="Outpainting", id=3):
# outpaint_web.render()
# with gr.TabItem(label="Upscaler", id=4):
# upscaler_web.render()
# if args.output_gallery:
# with gr.TabItem(label="Output Gallery", id=5) as og_tab:
# outputgallery_web.render()

# # extra output gallery configuration
# outputgallery_tab_select(og_tab.select)
# outputgallery_watch(
# [
# txt2img_status,
# img2img_status,
# inpaint_status,
# outpaint_status,
# upscaler_status,
# ]
# )
## with gr.TabItem(label="Model Manager", id=6):
## model_web.render()
## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
## lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=0):
chat_element.render()
## with gr.TabItem(
## label="Generate Sharding Config (Experimental)", id=9
## ):
## model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()

# send to buttons
# register_button_click(
# txt2img_sendto_img2img,
# 1,
# [txt2img_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_inpaint,
# 2,
# [txt2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_outpaint,
# 3,
# [txt2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_upscaler,
# 4,
# [txt2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_inpaint,
# 2,
# [img2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_outpaint,
# 3,
# [img2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_upscaler,
# 4,
# [img2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_img2img,
# 1,
# [inpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_outpaint,
# 3,
# [inpaint_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_upscaler,
# 4,
# [inpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_img2img,
# 1,
# [outpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_inpaint,
# 2,
# [outpaint_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_upscaler,
# 4,
# [outpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_img2img,
# 1,
# [upscaler_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_inpaint,
# 2,
# [upscaler_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_outpaint,
# 3,
# [upscaler_gallery],
# [outpaint_init_image, tabs],
# )
# if args.output_gallery:
# register_outputgallery_button(
# outputgallery_sendto_txt2img,
# 0,
# [outputgallery_filename],
# [txt2img_png_info_img, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_img2img,
# 1,
# [outputgallery_filename],
# [img2img_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_inpaint,
# 2,
# [outputgallery_filename],
# [inpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_outpaint,
# 3,
# [outputgallery_filename],
# [outpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_upscaler,
# 4,
# [outputgallery_filename],
# [upscaler_init_image, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_txt2img,
# 0,
# [hf_models],
# [txt2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_img2img,
# 1,
# [hf_models],
# [img2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_inpaint,
# 2,
# [hf_models],
# [inpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_outpaint,
# 3,
# [hf_models],
# [outpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_upscaler,
# 4,
# [hf_models],
# [upscaler_custom_model, tabs],
# )

sd_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
sd_web.launch(
share=True,
inbrowser=True,
server_name="0.0.0.0",
server_port=11911,  # args.server_port,
)
0
apps/shark_studio/web/ui/__init__.py
Normal file
517
apps/shark_studio/web/ui/chat.py
Normal file
@@ -0,0 +1,517 @@
import gradio as gr
import os
from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.utils import (
get_available_devices,
)
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)


def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]


language_model = None


# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_13b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
}


def create_prompt(model_name, history, prompt_prefix):
return ""
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]

if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
return msg

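To make the llama2 branch above concrete, here is roughly what the (currently stubbed-out) body of create_prompt would return for a two-turn history with the system prompt enabled; the history values are illustrative only:

history = [["Hi", "Hello!"], ["How are you?", ""]]
# with prompt_prefix=True this yields approximately:
# "[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\nHi [/INST] Hello! [INST] How are you? [/INST]  "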
def get_default_config():
return False
import torch
from transformers import AutoTokenizer

hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile

model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()


# model_vmfb_key = ""


def chat_fn(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global language_model
if language_model is None:
language_model = LanguageModel(
model, device=device, precision=precision
)

language_model.chat(prompt_prefix)
return "", ""
global past_key_values
global model_vmfb_key

device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")

from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args

new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512

# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)

if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
if device_id is None:
# find the index of the first device matching the target triple,
# without shadowing the outer `device` string while iterating
for idx, vk_device in enumerate(vulkaninfo_list):
if get_vulkan_target_triple(vk_device) == vulkan_target_triple:
device_id = idx
break

assert (
device_id is not None
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")

elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")

if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)

if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")

prompt = create_prompt(model_name, history, prompt_prefix)

partial_text = ""
token_count = 0
total_time_ms = 0.001  # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)

return history, ""


def llm_chat_api(InputData: dict):
return None
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
)  # else it is the legacy `completion` api
# For Debugging input data from API
# if is_chat_completion_api:
# print(f"message -> role : {InputData['messages'][0]['role']}")
# print(f"message -> content : {InputData['messages'][0]['content']}")
# else:
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = (
InputData["model"] if "model" in InputData.keys() else "codegen"
)
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = (
None
if "max_tokens" not in InputData.keys()
else InputData["max_tokens"]
)
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512

# make it working for codegen first
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)

device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")

vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)

# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add functionality for multiple messages
prompt = create_prompt(
model_name, [(InputData["messages"][0]["content"], "")]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)

res = vicuna_model.generate(prompt)
res_op = None
for op in res:
res_op = op

if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
|
||||
},
"finish_reason": "stop",  # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop",  # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion"
if is_chat_completion_api
else "text_completion",
"created": int(end_time),
"choices": choices,
}


def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content


with gr.Blocks(title="Chat") as chat_element:
with gr.Row():
model_choices = list(llm_model_map.keys())
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = get_available_devices()
enabled = True
if len(supported_devices) == 0:
supported_devices = ["cpu-task"]
supported_devices = [x for x in supported_devices if "sync" not in x]
device = gr.Dropdown(
label="Device",
value=supported_devices[0],
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
value="int4",
choices=[
# "int4",
# "int8",
# "fp16",
"fp32",
],
visible=False,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)

chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)

with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)
@@ -29,6 +29,10 @@ from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
from apps.stable_diffusion.src.utils import (
resamplers,
resampler_list,
)


class Image2ImagePipeline(StableDiffusionPipeline):
@@ -91,26 +95,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# TODO: process with variable HxW combos

# Pre-process image
if resample_type == "Lanczos":
resample_type = Image.LANCZOS
elif resample_type == "Nearest Neighbor":
resample_type = Image.NEAREST
elif resample_type == "Bilinear":
resample_type = Image.BILINEAR
elif resample_type == "Bicubic":
resample_type = Image.BICUBIC
elif resample_type == "Adaptive":
resample_type = Image.ADAPTIVE
elif resample_type == "Antialias":
resample_type = Image.ANTIALIAS
elif resample_type == "Box":
resample_type = Image.BOX
elif resample_type == "Affine":
resample_type = Image.AFFINE
elif resample_type == "Cubic":
resample_type = Image.CUBIC
else:  # Fallback to Lanczos
resample_type = Image.LANCZOS
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)

image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)

@@ -42,3 +42,7 @@ from apps.stable_diffusion.src.utils.utils import (
_compile_module,
)
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
from apps.stable_diffusion.src.utils.resamplers import (
resamplers,
resampler_list,
)
12
apps/stable_diffusion/src/utils/resamplers.py
Normal file
@@ -0,0 +1,12 @@
import PIL.Image as Image

resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
"Bilinear": Image.Resampling.BILINEAR,
"Bicubic": Image.Resampling.BICUBIC,
"Hamming": Image.Resampling.HAMMING,
"Box": Image.Resampling.BOX,
}

resampler_list = resamplers.keys()
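This table is what lets the pipelines above replace the long if/elif chain with a single lookup. A minimal usage sketch, where image, width, and height stand in for whatever the caller already has:

from apps.stable_diffusion.src.utils.resamplers import resamplers

requested = "Bicubic"
# fall back to Lanczos for anything not in the table, mirroring Image2ImagePipeline
resample = resamplers.get(requested, resamplers["Lanczos"])
image = image.resize((width, height), resample=resample)  # assumes a PIL image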
@@ -2,6 +2,8 @@ import argparse
import os
from pathlib import Path

from apps.stable_diffusion.src.utils.resamplers import resampler_list


def path_expand(s):
return Path(s).expanduser().resolve()
@@ -168,17 +170,7 @@ p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
choices=resampler_list,
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
@@ -746,8 +738,9 @@ p.add_argument(
p.add_argument(
"--iree_rocm_target_chip",
type=str,
default="gfx1100",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Default gfx1100",
default="",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
|
||||
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
|
||||
)

args, unknown = p.parse_known_args()

@@ -1,6 +1,10 @@
import numpy as np
from PIL import Image
import torch
import os
from pathlib import Path
import torchvision
import time
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
@@ -10,6 +14,33 @@ from apps.stable_diffusion.src.utils.stencils import (
stencil = {}


def save_img(img):
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)

subdir = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
os.makedirs(subdir, exist_ok=True)
if isinstance(img, Image.Image):
img.save(
os.path.join(
subdir, "controlnet_" + str(int(time.time())) + ".png"
)
)
elif isinstance(img, np.ndarray):
img = Image.fromarray(img)
img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
else:
converter = torchvision.transforms.ToPILImage()
for i in img:
converter(i).save(
os.path.join(subdir, str(int(time.time())) + ".png")
)


def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
@@ -161,6 +192,7 @@ def hint_canny(
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map

@@ -176,6 +208,7 @@ def hint_openpose(
stencil["openpose"] = OpenposeDetector()

detected_map, _ = stencil["openpose"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map

@@ -187,6 +220,7 @@ def hint_scribble(image: Image.Image):

detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
save_img(detected_map)
return detected_map


@@ -199,5 +233,6 @@ def hint_zoedepth(image: Image.Image):
stencil["depth"] = ZoeDetector()

detected_map = stencil["depth"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map

@@ -118,7 +118,7 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
save_dir="",
debug=False,
generate_vmfb=True,
extra_args=None,
@@ -477,7 +477,14 @@ def get_available_devices():
f"{device_name} => {driver_name.replace('local', 'cpu')}"
)
else:
device_list.append(f"{device_name} => {driver_name}://{i}")
# for drivers with single devices
# let the default device be selected without any indexing
if len(device_list_dict) == 1:
device_list.append(f"{device_name} => {driver_name}")
else:
device_list.append(
f"{device_name} => {driver_name}://{i}"
)
return device_list

set_iree_runtime_flags()
@@ -556,6 +563,9 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags

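With this change a driver exposing a single device is listed without an index suffix, so the strings get_available_devices returns look roughly like the following (device names are illustrative):

# single-device driver: no ://index suffix
#   "AMD Radeon RX 7900 XTX => vulkan"
# multi-device driver: each entry keeps its index
#   "AMD Radeon RX 7900 XTX => vulkan://0"
#   "AMD Radeon PRO W7900 => vulkan://1"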
@@ -75,11 +75,11 @@ if __name__ == "__main__":
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
config_gradio_tmp_imgs_folder,
from apps.stable_diffusion.web.utils.tmp_configs import (
config_tmp,
)

config_gradio_tmp_imgs_folder()
config_tmp()
import gradio as gr

# Create custom models folders if they don't exist

55
apps/stable_diffusion/web/ui/common_ui_events.py
Normal file
@@ -0,0 +1,55 @@
from apps.stable_diffusion.web.ui.utils import (
HSLHue,
hsl_color,
get_lora_metadata,
)


# Answers HTML to show the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_file):
# tag frequency percentage that gets the maximum amount of the starting hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency percentage, above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'

if lora_file == "None":
return ["<div><i>No LoRA selected</i></div>"]
elif not lora_file.lower().endswith(".safetensors"):
return [
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
return [
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
elif metadata is None:
return [
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
else:
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]
@@ -105,6 +105,18 @@ body {
background-color: var(--background-fill-primary);
}

.generating.svelte-zlszon.svelte-zlszon {
border: none;
}

.generating {
border: none !important;
}

#chatbot {
height: 100% !important;
}

/* display in full width for desktop devices */
@media (min-width: 1536px)
{
@@ -246,10 +258,39 @@ footer {
background-color: var(--block-label-background-fill);
}

/* lora tag pills */
.lora-tags {
border: 1px solid var(--border-color-primary);
color: var(--block-info-text-color) !important;
padding: var(--block-padding);
}

.lora-tag {
display: inline-block;
height: 2em;
color: rgb(212 212 212) !important;
margin-right: 5pt;
margin-bottom: 5pt;
padding: 2pt 5pt;
border-radius: 5pt;
white-space: nowrap;
}

.lora-model {
margin-bottom: var(--spacing-lg);
color: var(--block-info-text-color) !important;
line-height: var(--line-sm);
}

/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}

.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs)
line-height: var(--line-xs);
}

.output_icon_button {

@@ -5,6 +5,7 @@ import gradio as gr
import PIL
from math import ceil
from PIL import Image

from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -14,6 +15,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -27,6 +29,7 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
resampler_list,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
@@ -435,6 +438,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -486,17 +494,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
@@ -647,3 +645,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

@@ -4,6 +4,7 @@ import time
import sys
import gradio as gr
from PIL import Image

from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -13,6 +14,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -319,6 +321,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -518,3 +525,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

@@ -3,9 +3,8 @@ import torch
import time
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException

from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -323,6 +322,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -546,3 +550,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

@@ -91,7 +91,7 @@ with gr.Blocks() as outputgallery_web:
value=gallery_files.value,
visible=False,
show_label=True,
columns=2,
columns=4,
)

with gr.Column(scale=4):
@@ -204,6 +204,9 @@ with gr.Blocks() as outputgallery_web:
),
]

def on_image_columns_change(columns):
return gr.Gallery.update(columns=columns)

def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
@@ -365,53 +368,6 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
)

# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
setTimeout(() => {{
required_style = "auto ".repeat(new_cols).trim();
gallery = document.querySelector('#outputgallery_gallery .grid-container');
if (gallery) {{
gallery.style.gridTemplateColumns = required_style
}}
}}, {timeout_length});
return []; // prevents console error from gradio
}}
"""

# --- Wire handlers up to the actions

# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)

# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)

# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
@@ -423,32 +379,35 @@ with gr.Blocks() as outputgallery_web:
queue=False,
)

image_columns.change(**set_gallery_columns_immediate)

subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)

open_subdir.click(
on_open_subdir, inputs=[subdirectories], queue=False
).then(**set_gallery_columns_immediate)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)

refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)

image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)

gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
).then(**set_gallery_columns_delayed)
)

outputgallery_filename.change(
on_outputgallery_filename_change,
@@ -477,7 +436,7 @@ with gr.Blocks() as outputgallery_web:
open_subdir,
],
queue=False,
).then(**set_gallery_columns_immediate)
)

# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
@@ -489,4 +448,4 @@ with gr.Blocks() as outputgallery_web:
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)

@@ -6,6 +6,7 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from shark.iree_utils.compile_utils import clean_device_info
from datetime import datetime as dt
import json
import sys
@@ -151,21 +152,8 @@ def chat(
global model_vmfb_key
global vicuna_model

device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
device, device_id = clean_device_info(device)

from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
@@ -220,10 +208,11 @@ def chat(

elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if args.iree_rocm_target_chip != "":
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")

if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
@@ -322,17 +311,7 @@ def llm_chat_api(InputData: dict):

device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
device, device_id = clean_device_info(device)

vicuna_model = UnshardedVicuna(
model_name,
@@ -457,7 +436,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
chatbot = gr.Chatbot(height=500)
chatbot = gr.Chatbot(elem_id="chatbot")
with gr.Row():
with gr.Column():
msg = gr.Textbox(

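clean_device_info replaces the hand-rolled if/elif chains removed above. Its actual implementation lives in shark.iree_utils.compile_utils; what follows is only a plausible sketch of the parsing it performs, assuming the "driver://index" convention used elsewhere in this diff (the function body here is hypothetical):

def clean_device_info_sketch(raw: str):
    # hypothetical re-implementation for illustration only:
    # "a => vulkan://1" -> ("vulkan", 1); "b => cpu-task" -> ("cpu-task", None)
    device = raw.split("=>")[-1].strip()
    device_id = None
    if "://" in device:
        device, idx = device.split("://")
        device_id = int(idx)
    return device, device_id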
@@ -5,15 +5,18 @@ import sys
import gradio as gr
from PIL import Image
from math import ceil

from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -29,6 +32,7 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
resampler_list,
)

# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
@@ -394,6 +398,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -465,50 +474,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Low VRAM",
interactive=True,
)
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
label="Resample Type",
allow_custom_value=True,
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
@@ -532,6 +497,41 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Accordion(label="Hires Fix Options", open=False):
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=False,
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
@@ -676,3 +676,30 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
custom_vae,
],
)

# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
def set_compatible_schedulers(hires_fix_selected):
if hires_fix_selected:
return gr.Dropdown.update(
choices=scheduler_list_cpu_only,
value="DEISMultistep",
)
else:
return gr.Dropdown.update(
choices=scheduler_list,
value="SharkEulerDiscrete",
)

use_hiresfix.change(
fn=set_compatible_schedulers,
inputs=[use_hiresfix],
outputs=[scheduler],
queue=False,
)

lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

@@ -3,6 +3,7 @@ import torch
import time
import gradio as gr
from PIL import Image

from apps.stable_diffusion.web.ui.utils import (
    available_devices,
    nodlogo_loc,
@@ -12,6 +13,7 @@ from apps.stable_diffusion.web.ui.utils import (
    predefined_upscaler_models,
    cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
    args,
@@ -340,6 +342,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
                    label="HuggingFace Model ID",
                    lines=3,
                )
            with gr.Row():
                lora_tags = gr.HTML(
                    value="<div><i>No LoRA selected</i></div>",
                    elem_classes="lora-tags",
                )
        with gr.Accordion(label="Advanced Options", open=False):
            with gr.Row():
                scheduler = gr.Dropdown(
@@ -537,3 +544,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
        fn=cancel_sd,
        cancels=[prompt_submit, neg_prompt_submit, generate_click],
    )

    lora_weights.change(
        fn=lora_changed,
        inputs=[lora_weights],
        outputs=[lora_tags],
        queue=True,
    )

@@ -1,10 +1,16 @@
import os
import sys
from apps.stable_diffusion.src import get_available_devices
import glob
import math
import json
import safetensors

from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
from enum import IntEnum

from apps.stable_diffusion.src import get_available_devices
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
    SD_STATE_CANCEL,
@@ -28,6 +34,15 @@ class Config:
    ondemand: str  # should this be expecting a bool instead?


class HSLHue(IntEnum):
    RED = 0
    YELLOW = 60
    GREEN = 120
    CYAN = 180
    BLUE = 240
    MAGENTA = 300


custom_model_filetypes = (
    "*.ckpt",
    "*.safetensors",
@@ -161,6 +176,69 @@ def get_custom_vae_or_lora_weights(weights, hf_id, model):
    return use_weight


def hsl_color(alpha: float, start, end):
    b = (end - start) * (alpha if alpha > 0 else 0)
    result = b + start

    # Return a CSS HSL string
    return f"hsl({math.floor(result)}, 80%, 35%)"


def get_lora_metadata(lora_filename):
    # get the metadata from the file
    filename = get_custom_model_pathfile(lora_filename, "lora")
    with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
        metadata = f.metadata()

    # guard clause for if there isn't any metadata
    if not metadata:
        return None

    # metadata is a dictionary of strings; the values of the keys we're
    # interested in are actually json, and need to be loaded as such
    tag_frequencies = json.loads(metadata.get("ss_tag_frequency", "{}"))
    dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", "{}"))
    tag_dirs = [dir for dir in tag_frequencies.keys()]

    # gather the tag frequency information for all the datasets trained
    all_frequencies = {}
    for dataset in tag_dirs:
        frequencies = sorted(
            [entry for entry in tag_frequencies[dataset].items()],
            reverse=True,
            key=lambda x: x[1],
        )

        # get a figure for the total number of images processed for this dataset:
        # either the number listed in its dataset_dirs entry, or the highest
        # tag frequency if that entry doesn't exist
        img_count = dataset_dirs.get(dataset, {}).get(
            "img_count", frequencies[0][1]
        )

        # add the dataset frequencies to the overall frequencies replacing the
        # frequency counts on the tags with a percentage/ratio
        all_frequencies.update(
            [(entry[0], entry[1] / img_count) for entry in frequencies]
        )

    trained_model_id = " ".join(
        [
            metadata.get("ss_sd_model_hash", ""),
            metadata.get("ss_sd_model_name", ""),
            metadata.get("ss_base_model_version", ""),
        ]
    ).strip()

    # return the model the LoRA was trained against and the tag frequencies
    # across all datasets, highest ratio first
    return {
        "model": trained_model_id,
        "frequencies": sorted(
            all_frequencies.items(), reverse=True, key=lambda x: x[1]
        ),
    }
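A quick sketch of how these two helpers are meant to compose (illustration only: the filename, tags and ratios are invented, and `get_lora_metadata` returns `None` for files without metadata):

```python
# Render each tag from a LoRA's metadata as an HTML chip whose colour runs
# from HSLHue.RED (rare) to HSLHue.GREEN (common). Hypothetical usage sketch.
meta = get_lora_metadata("my_lora.safetensors")  # example filename
if meta:
    chips = [
        f'<span style="color: {hsl_color(ratio, HSLHue.RED, HSLHue.GREEN)}">{tag}</span>'
        for tag, ratio in meta["frequencies"]
    ]
```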


def cancel_sd():
    # Wrap in try/except, as gc can delete global_obj.sd_obj while switching models
    try:

@@ -2,26 +2,25 @@ import os
import sys
import webview
import webview.util
import winreg
import socket

from contextlib import closing
from multiprocessing import Process
from tkinter import Tk

from apps.stable_diffusion.src import args


def webview2_installed():
    if sys.platform != "win32":
        return False

    # On windows we want to ensure we have MS webview2 available so we don't fall back
    # to MSHTML (aka ye olde Internet Explorer) which is deprecated by pywebview, and
    # apparently causes SHARK not to load properly.

    # Checking these registry entries is how Microsoft says to detect a webview2 installation:
    # https://learn.microsoft.com/en-us/microsoft-edge/webview2/concepts/distribution

    if sys.platform != "win32":
        return False
    import winreg

    path = r"SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}"

@@ -54,6 +53,8 @@ def webview2_installed():


def window(address):
    from tkinter import Tk

    window = Tk()

    # get screen width and height of display and make it more reasonably

@@ -5,11 +5,25 @@ from time import time
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")


def config_gradio_tmp_imgs_folder():
    # create shark_tmp if it does not exist
    if not os.path.exists(shark_tmp):
        os.mkdir(shark_tmp)
def clear_tmp_mlir():
    cleanup_start = time()
    print(
        "Clearing .mlir temporary files from a prior run. This may take some time..."
    )
    mlir_files = [
        filename
        for filename in os.listdir(shark_tmp)
        if os.path.isfile(os.path.join(shark_tmp, filename))
        and filename.endswith(".mlir")
    ]
    for filename in mlir_files:
        os.remove(shark_tmp + filename)
    print(
        f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
    )


def clear_tmp_imgs():
    # tell gradio to use a directory under shark_tmp for its temporary
    # image files unless somewhere else has been set
    if "GRADIO_TEMP_DIR" not in os.environ:
@@ -52,3 +66,12 @@ def config_gradio_tmp_imgs_folder():
        )
    else:
        print("No temporary image files to clear.")


def config_tmp():
    # create shark_tmp if it does not exist
    if not os.path.exists(shark_tmp):
        os.mkdir(shark_tmp)

    clear_tmp_mlir()
    clear_tmp_imgs()

@@ -78,7 +78,10 @@ def test_loop(
        os.mkdir("./test_images/golden")
    get_inpaint_inputs()
    hf_model_names = model_config_dicts[0].values()
    tuned_options = ["--no-use_tuned", "--use_tuned"]
    tuned_options = [
        "--no-use_tuned",
        "--use_tuned",
    ]
    import_options = ["--import_mlir", "--no-import_mlir"]
    prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
    inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
@@ -112,6 +115,8 @@ def test_loop(
            and use_tune == tuned_options[1]
        ):
            continue
        elif use_tune == tuned_options[1]:
            continue
        command = (
            [
                executable,  # executable is the python from the venv used to run this

@@ -22,33 +22,33 @@ This does mean however, that on a brand new fresh install of SHARK that has not

* Make sure you have suitable drivers for your graphics card installed. See the prerequisites section of the [README](https://github.com/nod-ai/SHARK#readme).
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_cors_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the section on connecting to SHARK on a different address or port below if you want to run SHARK on a different port.)
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the section on connecting to SHARK on a different address or port below if you want to run SHARK on a different port.)

```powershell
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_cors_origin="*" --server_port=7860
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_accept_origin="*" --server_port=7860

## Run from the base directory of a source clone of SHARK on Windows:
.\setup_venv.ps1
python .\apps\stable_diffusion\web\index.py --api --api_cors_origin="*" --server_port=7860
python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860

## Run from the base directory of a source clone of SHARK on Linux:
./setup_venv.sh
source shark.venv/bin/activate
python ./apps/stable_diffusion/web/index.py --api --api_cors_origin="*" --server_port=7860
python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860

## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860

## Since the API respects most applicable SHARK command line arguments for options not specified,
## or not currently implemented by the API, there might be some you want to set, as listed in `--help`
.\node_ai_shark_studio_20320901_2525.exe --help

## For instance, the example above, but with a custom VAE specified
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"

## An example with multiple specific CORS origins
python apps/stable_diffusion/web/index.py --api --api_cors_origin="koboldcpp.example.com:7001" --api_cors_origin="koboldcpp.example.com:7002" --server_port=7860
python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
```
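Once the server is up you can sanity-check the API before pointing koboldcpp at it. A minimal sketch, assuming the `requests` package and the A1111-style `/sdapi/v1/txt2img` route that koboldcpp targets; the prompt and `steps` values are examples only:

```python
import base64

import requests  # any HTTP client will do; requests is assumed here

resp = requests.post(
    "http://127.0.0.1:7860/sdapi/v1/txt2img",
    json={"prompt": "a lighthouse at dusk", "steps": 20},
)
resp.raise_for_status()

# A1111-style responses carry base64-encoded images.
with open("smoke_test.png", "wb") as f:
    f.write(base64.b64decode(resp.json()["images"][0]))
```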

SHARK should start in server mode, and you should see something like this:
@@ -83,11 +83,11 @@ SHARK should start in server mode, and you should see something like this:

* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings. Here there is, not very obviously, a 'style' button:

![koboldcpp AI 'style' button](https://user-images.githubusercontent.com/121311569/275654315-30ae8e75-2bd6-4b58-b205-0a5a1a047a75.jpg)
![koboldcpp AI 'style' button](https://user-images.githubusercontent.com/121311569/281143195-12cf9698-0d56-4c9e-aaad-6ab08ba9dc7a.jpg)

This will bring up a dialog box where you can enter a short text that will be sent as a prefix to the prompt sent to SHARK:

![koboldcpp styles dialog](https://user-images.githubusercontent.com/121311569/275654319-0559e0f6-5b4a-41e5-b557-5df4de825898.jpg)
![koboldcpp styles dialog](https://user-images.githubusercontent.com/121311569/281143255-6d56b8f8-658b-42cc-8b9c-563b3a2c158e.jpg)


## Connecting to SHARK on a different address or port

@@ -17,6 +17,7 @@ pytest-forked
Pillow
parameterized

#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main
# Add transformers, diffusers and scipy since they are most commonly used
tokenizers==0.13.3
transformers
@@ -42,6 +43,7 @@ joblib # for langchain
timm # for MiniGPT4
langchain
einops # for zoedepth
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions

# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
pefile
@@ -49,3 +51,7 @@ pyinstaller

# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea

# For quantized GPTQ models
optimum
auto_gptq
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
@@ -111,7 +111,7 @@ else
fi
if [[ -z "${NO_BACKEND}" ]]; then
  echo "Installing ${RUNTIME}..."
  $PYTHON -m pip install --pre --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
  $PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler iree-runtime
else
  echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi
@@ -19,9 +19,12 @@ import sys
import subprocess


def run_cmd(cmd, debug=False):
def run_cmd(cmd, debug=False, raise_err=False):
    """
    Inputs: cli command string.
    Inputs:
        cmd : cli command string.
        debug : if True, prints debug info
        raise_err : if True, raise exception to caller
    """
    if debug:
        print("IREE run command: \n\n")
@@ -39,8 +42,11 @@ def run_cmd(cmd, debug=False):
        stderr = result.stderr.decode()
        return stdout, stderr
    except subprocess.CalledProcessError as e:
        print(e.output)
        sys.exit(f"Exiting program due to error running {cmd}")
        if raise_err:
            raise Exception from e
        else:
            print(e.output)
            sys.exit(f"Exiting program due to error running {cmd}")
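A sketch of what the new `raise_err` path allows (the command string is just an example): by default a failed command still exits the process; with `raise_err=True` the caller can catch the failure and fall back.

```python
try:
    # raise_err=True: the CalledProcessError is re-raised instead of sys.exit()
    stdout, stderr = run_cmd("iree-run-module --dump_devices=rocm", raise_err=True)
except Exception:
    stdout = None  # e.g. fall back to a default device architecture
```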


def iree_device_map(device):

@@ -16,7 +16,6 @@ import numpy as np
import os
import re
import tempfile
import time
from pathlib import Path

import iree.runtime as ireert
@@ -32,24 +31,7 @@ from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
    print("Configuring for device:" + device)
    device_uri = device.split("://")
    if len(device_uri) > 1:
        if device_uri[0] not in ["vulkan"]:
            print(
                f"Specific device selection only supported for vulkan now."
                f"Proceeding with {device} as device."
            )
        # device_uri can be device_num or device_path.
        # assuming number of devices for a single driver will not be >99
        if len(device_uri[1]) <= 2:
            # expected to be device index in range 0 - 99
            device_num = int(device_uri[1])
        else:
            # expected to be device path
            device_num = device_uri[1]

    else:
        device_num = 0
    device, device_num = clean_device_info(device)

    if "cpu" in device:
        from shark.iree_utils.cpu_utils import get_iree_cpu_args
@@ -63,30 +45,52 @@ def get_iree_device_args(device, extra_args=[]):
            + data_tiling_flag
            + u_kernel_flag
            + stack_size_flag
            + ["--iree-flow-enable-quantized-matmul-reassociation"]
            + ["--iree-llvmcpu-enable-quantized-matmul-reassociation"]
            + ["--iree-global-opt-enable-quantized-matmul-reassociation"]
        )
    if device_uri[0] == "cuda":
    if device == "cuda":
        from shark.iree_utils.gpu_utils import get_iree_gpu_args

        return get_iree_gpu_args()
    if device_uri[0] == "vulkan":
    if device == "vulkan":
        from shark.iree_utils.vulkan_utils import get_iree_vulkan_args

        return get_iree_vulkan_args(
            device_num=device_num, extra_args=extra_args
        )
    if device_uri[0] == "metal":
    if device == "metal":
        from shark.iree_utils.metal_utils import get_iree_metal_args

        return get_iree_metal_args(extra_args=extra_args)
    if device_uri[0] == "rocm":
    if device == "rocm":
        from shark.iree_utils.gpu_utils import get_iree_rocm_args

        return get_iree_rocm_args(extra_args=extra_args)
        return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
    return []


def clean_device_info(raw_device):
    # return appropriate device and device_id for consumption by Studio pipeline
    # Multiple devices only supported for vulkan and rocm (as of now).
    # default device must be selected for all others

    device_id = None
    device = (
        raw_device
        if "=>" not in raw_device
        else raw_device.split("=>")[1].strip()
    )
    if "://" in device:
        device, device_id = device.split("://")
        if len(device_id) <= 2:
            device_id = int(device_id)

    if device not in ["rocm", "vulkan"]:
        device_id = ""
    if device in ["rocm", "vulkan"] and device_id == None:
        device_id = 0
    return device, device_id
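Expected behaviour of `clean_device_info`, read straight off the logic above (the Studio-style `"name => device"` string is an invented example):

```python
# "=>"-prefixed names are stripped, "://" suffixes become integer ids, and
# only rocm/vulkan keep a device id (defaulting to 0).
assert clean_device_info("AMD Radeon RX 7900 XTX => vulkan://1") == ("vulkan", 1)
assert clean_device_info("rocm") == ("rocm", 0)
assert clean_device_info("cpu-task") == ("cpu-task", "")
```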


# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
    if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
@@ -321,6 +325,8 @@ def compile_module_to_flatbuffer(
        input_type = "tosa"
    elif frontend in ["tm_tensor"]:
        input_type = ireec.InputType.TM_TENSOR
    elif frontend in ["torch", "pytorch"]:
        input_type = "torch"

    if compile_str:
        flatbuffer_blob = ireec.compile_str(
@@ -632,7 +638,7 @@ def get_results(
def get_iree_runtime_config(device):
    device = iree_device_map(device)
    haldriver = ireert.get_driver(device)
    if device == "metal" and shark_args.device_allocator == "caching":
    if "metal" in device and shark_args.device_allocator == "caching":
        print(
            "[WARNING] metal devices can not have a `caching` allocator."
            "\nUsing default allocator `None`"
@@ -640,7 +646,9 @@ def get_iree_runtime_config(device):
    haldevice = haldriver.create_device_by_uri(
        device,
        # metal devices have a failure with caching allocators atm. Blocking this until it gets fixed upstream.
        allocators=shark_args.device_allocator if device != "metal" else None,
        allocators=shark_args.device_allocator
        if "metal" not in device
        else None,
    )
    config = ireert.Config(device=haldevice)
    return config

@@ -18,7 +18,11 @@ import functools
import iree.runtime as ireert
import ctypes
import sys
from subprocess import CalledProcessError
from shark.parser import shark_args
from shark.iree_utils._common import run_cmd

# TODO: refactor to rocm and cuda utils


# Get the default gpu args given the architecture.
@@ -39,24 +43,83 @@ def get_iree_gpu_args():
    return []


def check_rocm_device_arch_in_args(extra_args):
    # Check if the target arch flag for the rocm device is present in extra_args
    for flag in extra_args:
        if "iree-rocm-target-chip" in flag:
            flag_arch = flag.split("=")[1]
            return flag_arch
    return None


def get_rocm_device_arch(device_num=0, extra_args=[]):
    # ROCm device arch selection:
    # 1 : user given device arch via the `--iree-rocm-target-chip` flag
    # 2 : device arch from `iree-run-module --dump_devices=rocm` for the device at index <device_num>
    # 3 : default arch : gfx1100

    arch_in_flag = check_rocm_device_arch_in_args(extra_args)
    if arch_in_flag is not None:
        print(
            f"User specified rocm target device arch from flag : {arch_in_flag} will be used"
        )
        return arch_in_flag

    arch_in_device_dump = None

    # get rocm arch from iree dump devices
    def get_devices_info_from_dump(dump):
        from os import linesep

        dump_clean = list(
            filter(
                lambda s: "--device=rocm" in s or "gpu-arch-name:" in s,
                dump.split(linesep),
            )
        )
        arch_pairs = [
            (
                dump_clean[i].split("=")[1].strip(),
                dump_clean[i + 1].split(":")[1].strip(),
            )
            for i in range(0, len(dump_clean), 2)
        ]
        return arch_pairs

    dump_device_info = None
    try:
        dump_device_info = run_cmd(
            "iree-run-module --dump_devices=rocm", raise_err=True
        )
    except Exception as e:
        print("could not execute `iree-run-module --dump_devices=rocm`")

    if dump_device_info is not None:
        device_num = 0 if device_num is None else device_num
        device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
        if len(device_arch_pairs) > device_num:  # can find arch in the list
            arch_in_device_dump = device_arch_pairs[device_num][1]

    if arch_in_device_dump is not None:
        print(f"Found ROCm device arch : {arch_in_device_dump}")
        return arch_in_device_dump

    default_rocm_arch = "gfx1100"
    print(
        "Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
        "\n or from `iree-run-module --dump_devices=rocm` command."
        f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
    )
    return default_rocm_arch


# Get the default rocm args given the architecture.
def get_iree_rocm_args(extra_args=[]):
def get_iree_rocm_args(device_num=0, extra_args=[]):
    ireert.flags.FUNCTION_INPUT_VALIDATION = False
    rocm_flags = ["--iree-rocm-link-bc=true"]

    # Add the target arch flag for the rocm device
    flag_present = False
    for flag in extra_args:
        if "iree-rocm-target-chip" in flag:
            flag_present = True
            print(
                f"found rocm target device arch from flag : {flag.split('=')[1]}"
            )
    if not flag_present:
        print(
            "Failed to find ROCm architecture from hipinfo / rocminfo. Defaulting to gfx1100."
        )
        rocm_arch = "gfx1100"
    if check_rocm_device_arch_in_args(extra_args) is None:
        rocm_arch = get_rocm_device_arch(device_num, extra_args)
        rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")

    return rocm_flags
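A sketch of the selection order this gives (the `gfx90a` value is an example arch, not a recommendation):

```python
# 1. an explicit flag wins and is returned as-is
assert get_rocm_device_arch(0, ["--iree-rocm-target-chip=gfx90a"]) == "gfx90a"

# 2./3. with no flag, the `iree-run-module --dump_devices=rocm` dump is
# consulted, falling back to gfx1100; get_iree_rocm_args then appends it
flags = get_iree_rocm_args(device_num=0)
# -> ["--iree-rocm-link-bc=true", "--iree-rocm-target-chip=<detected or gfx1100>"]
```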
@@ -27,22 +27,35 @@ from shark.parser import shark_args
def get_all_vulkan_devices():
    from iree.runtime import get_driver

    driver = get_driver("vulkan")
    device_list_src = driver.query_available_devices()
    try:
        driver = get_driver("vulkan")
        device_list_src = driver.query_available_devices()
    except:
        device_list_src = {}

    return [d["name"] for d in device_list_src]


@functools.cache
def get_vulkan_device_name(device_num=0):
    vulkaninfo_list = get_all_vulkan_devices()
    if len(vulkaninfo_list) == 0:
        raise ValueError("No device name found in VulkanInfo!")
    if len(vulkaninfo_list) > 1:
        print("Following devices found:")
        for i, dname in enumerate(vulkaninfo_list):
            print(f"{i}. {dname}")
        print(f"Choosing device: {vulkaninfo_list[device_num]}")
    return vulkaninfo_list[device_num]
    if isinstance(device_num, int):
        vulkaninfo_list = get_all_vulkan_devices()

        if len(vulkaninfo_list) == 0:
            raise ValueError("No device name found in VulkanInfo!")
        if len(vulkaninfo_list) > 1:
            print("Following devices found:")
            for i, dname in enumerate(vulkaninfo_list):
                print(f"{i}. {dname}")
            print(f"Choosing device: vulkan://{device_num}")
        vulkan_device_name = vulkaninfo_list[device_num]
    else:
        from iree.runtime import get_driver

        vulkan_device_driver = get_driver(device_num)
        vulkan_device_name = vulkan_device_driver.query_available_devices()[0]
        print(vulkan_device_name)
    return vulkan_device_name
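The function now accepts either an index into the enumerated Vulkan devices or, in the `else` branch, a value handed straight to `get_driver` (a sketch; the string form below is hypothetical):

```python
name = get_vulkan_device_name(0)  # int: index into get_all_vulkan_devices()

# Hypothetical non-int form, resolved through get_driver(device_num):
# name = get_vulkan_device_name("vulkan://some-device-path")
```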


def get_os_name():

@@ -7,7 +7,7 @@ import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from typing import List, Tuple
from io import BytesIO
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


@@ -84,7 +84,7 @@ def compile_int_precision(
        weight_quant_type="asym",
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_scale_precision="float_scale",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,
        quantize_weight_zero_point=False,

@@ -800,15 +800,17 @@ def save_mlir(
    model_name,
    mlir_dialect="linalg",
    frontend="torch",
    dir=tempfile.gettempdir(),
    dir="",
):
    model_name_mlir = (
        model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
    )
    if dir == "":
        dir = tempfile.gettempdir()
        dir = os.path.join(".", "shark_tmp")
    mlir_path = os.path.join(dir, model_name_mlir)
    print(f"saving {model_name_mlir} to {dir}")
    if not os.path.exists(dir):
        os.makedirs(dir)
    if frontend == "torch":
        with open(mlir_path, "wb") as mlir_file:
            mlir_file.write(mlir_module)
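With the new default, artifacts land under `./shark_tmp` rather than the system temp dir (a sketch; the module bytes and model name are placeholders):

```python
mlir_module = b"module {}"  # placeholder payload
save_mlir(mlir_module, "example_model")
# -> writes ./shark_tmp/example_model_torch_linalg.mlir, creating the
#    directory first if needed; pass dir=... to override the location.
```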
@@ -1,21 +1,19 @@
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"

@@ -71,4 +71,4 @@ if __name__ == "__main__":
        args.max_seq_len,
        print_intermediate_results=False,
    )
    print("reponse: {}".format("".join(response)))
    print("response: {}".format("".join(response)))

@@ -47,9 +47,8 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
        is_dynamic = row[3]
        mlir_type = row[4]
        is_decompose = row[5]

        tracing_required = True
        is_dynamic = False if is_dynamic == "False" else True
        tracing_required = False if tracing_required == "False" else True
        is_dynamic = False
        print("generating artifacts for: " + torch_model_name)
        model = None
        input = None