Merge branch 'main' into vicuna-py-sharded-update

This commit is contained in:
Eliasj42
2023-12-03 11:23:15 -08:00
committed by GitHub
44 changed files with 3005 additions and 824 deletions

View File

@@ -112,7 +112,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --benchmark=native --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
python build_tools/vicuna_testing.py
@@ -121,9 +121,9 @@ jobs:
if: matrix.suite == 'cuda'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --benchmark=native --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -144,10 +144,10 @@ jobs:
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
run: |
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
pytest --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail
- name: Validate Vulkan Models (Windows)
if: matrix.suite == 'vulkan' && matrix.os == '7950x'

View File

@@ -2425,4 +2425,4 @@ if __name__ == "__main__":
if args.enable_microbenchmark:
print("\n### Final Statistics ###")
print_aggregate_stats(benchmark_run_infos)
print_aggregate_stats(benchmark_run_infos)

View File

@@ -69,91 +69,7 @@ class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
return torch.tensor(new_hidden_states)
class DecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
def forward(self, hidden_states, attention_mask):
output = self.model.forward(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
return (output[0], output[1][0], output[1][1])
class CompiledDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
new_hidden_states, pkv1, pkv2 = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
del self.model
return tuple(
[
torch.tensor(new_hidden_states),
tuple(
[
torch.tensor(pkv1),
torch.tensor(pkv2),
]
),
]
)
class EightDecoderLayer(torch.nn.Module):
class FourWayShardingDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
@@ -175,163 +91,78 @@ class EightDecoderLayer(torch.nn.Module):
outputs[-1][1],
)
)
if self.falcon_variant == "7b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
)
elif self.falcon_variant == "40b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
)
elif self.falcon_variant == "180b":
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
)
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
)
return result
class CompiledEightDecoderLayer(torch.nn.Module):
class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision
self, layer_id, device_idx, falcon_variant, device, precision, model
):
super().__init__()
self.layer_id = layer_id
@@ -339,12 +170,14 @@ class CompiledEightDecoderLayer(torch.nn.Module):
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
self.model = model
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
alibi: torch.Tensor = None,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -354,24 +187,12 @@ class CompiledEightDecoderLayer(torch.nn.Module):
torch.cuda.empty_cache()
gc.collect()
from pathlib import Path
from apps.language_models.utils import get_vmfb_from_path
self.falcon_vmfb_path = Path(
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
)
print("vmfb path for layer: ", self.falcon_vmfb_path)
self.model = get_vmfb_from_path(
self.falcon_vmfb_path,
self.device,
"linalg",
device_id=self.device_index,
)
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.detach().numpy()
attention_mask = attention_mask.to(torch.float32).detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
@@ -383,196 +204,452 @@ class CompiledEightDecoderLayer(torch.nn.Module):
attention_mask,
),
)
del self.model
if self.falcon_variant == "7b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
)
return result
class TwoWayShardingDecoderLayer(torch.nn.Module):
def __init__(self, decoder_layer_model, falcon_variant):
super().__init__()
self.model = decoder_layer_model
self.falcon_variant = falcon_variant
def forward(self, hidden_states, attention_mask):
new_pkvs = []
for layer in self.model:
outputs = layer(
hidden_states=hidden_states,
alibi=None,
attention_mask=attention_mask,
use_cache=True,
)
elif self.falcon_variant == "40b":
result = (
torch.tensor(output[0]),
hidden_states = outputs[0]
new_pkvs.append(
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
)
elif self.falcon_variant == "180b":
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
outputs[-1][0],
outputs[-1][1],
)
)
(
(new_pkv00, new_pkv01),
(new_pkv10, new_pkv11),
(new_pkv20, new_pkv21),
(new_pkv30, new_pkv31),
(new_pkv40, new_pkv41),
(new_pkv50, new_pkv51),
(new_pkv60, new_pkv61),
(new_pkv70, new_pkv71),
(new_pkv80, new_pkv81),
(new_pkv90, new_pkv91),
(new_pkv100, new_pkv101),
(new_pkv110, new_pkv111),
(new_pkv120, new_pkv121),
(new_pkv130, new_pkv131),
(new_pkv140, new_pkv141),
(new_pkv150, new_pkv151),
(new_pkv160, new_pkv161),
(new_pkv170, new_pkv171),
(new_pkv180, new_pkv181),
(new_pkv190, new_pkv191),
(new_pkv200, new_pkv201),
(new_pkv210, new_pkv211),
(new_pkv220, new_pkv221),
(new_pkv230, new_pkv231),
(new_pkv240, new_pkv241),
(new_pkv250, new_pkv251),
(new_pkv260, new_pkv261),
(new_pkv270, new_pkv271),
(new_pkv280, new_pkv281),
(new_pkv290, new_pkv291),
(new_pkv300, new_pkv301),
(new_pkv310, new_pkv311),
(new_pkv320, new_pkv321),
(new_pkv330, new_pkv331),
(new_pkv340, new_pkv341),
(new_pkv350, new_pkv351),
(new_pkv360, new_pkv361),
(new_pkv370, new_pkv371),
(new_pkv380, new_pkv381),
(new_pkv390, new_pkv391),
) = new_pkvs
result = (
hidden_states,
new_pkv00,
new_pkv01,
new_pkv10,
new_pkv11,
new_pkv20,
new_pkv21,
new_pkv30,
new_pkv31,
new_pkv40,
new_pkv41,
new_pkv50,
new_pkv51,
new_pkv60,
new_pkv61,
new_pkv70,
new_pkv71,
new_pkv80,
new_pkv81,
new_pkv90,
new_pkv91,
new_pkv100,
new_pkv101,
new_pkv110,
new_pkv111,
new_pkv120,
new_pkv121,
new_pkv130,
new_pkv131,
new_pkv140,
new_pkv141,
new_pkv150,
new_pkv151,
new_pkv160,
new_pkv161,
new_pkv170,
new_pkv171,
new_pkv180,
new_pkv181,
new_pkv190,
new_pkv191,
new_pkv200,
new_pkv201,
new_pkv210,
new_pkv211,
new_pkv220,
new_pkv221,
new_pkv230,
new_pkv231,
new_pkv240,
new_pkv241,
new_pkv250,
new_pkv251,
new_pkv260,
new_pkv261,
new_pkv270,
new_pkv271,
new_pkv280,
new_pkv281,
new_pkv290,
new_pkv291,
new_pkv300,
new_pkv301,
new_pkv310,
new_pkv311,
new_pkv320,
new_pkv321,
new_pkv330,
new_pkv331,
new_pkv340,
new_pkv341,
new_pkv350,
new_pkv351,
new_pkv360,
new_pkv361,
new_pkv370,
new_pkv371,
new_pkv380,
new_pkv381,
new_pkv390,
new_pkv391,
)
return result
class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
def __init__(
self, layer_id, device_idx, falcon_variant, device, precision, model
):
super().__init__()
self.layer_id = layer_id
self.device_index = device_idx
self.falcon_variant = falcon_variant
self.device = device
self.precision = precision
self.model = model
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
import gc
torch.cuda.empty_cache()
gc.collect()
if self.model is None:
raise ValueError("Layer vmfb not found")
hidden_states = hidden_states.to(torch.float32).detach().numpy()
attention_mask = attention_mask.to(torch.float32).detach().numpy()
if alibi is not None or layer_past is not None:
raise ValueError("Past Key Values and alibi should be None")
else:
raise ValueError(
"Unsupported Falcon variant: ", self.falcon_variant
output = self.model(
"forward",
(
hidden_states,
attention_mask,
),
)
result = (
torch.tensor(output[0]),
(
torch.tensor(output[1]),
torch.tensor(output[2]),
),
(
torch.tensor(output[3]),
torch.tensor(output[4]),
),
(
torch.tensor(output[5]),
torch.tensor(output[6]),
),
(
torch.tensor(output[7]),
torch.tensor(output[8]),
),
(
torch.tensor(output[9]),
torch.tensor(output[10]),
),
(
torch.tensor(output[11]),
torch.tensor(output[12]),
),
(
torch.tensor(output[13]),
torch.tensor(output[14]),
),
(
torch.tensor(output[15]),
torch.tensor(output[16]),
),
(
torch.tensor(output[17]),
torch.tensor(output[18]),
),
(
torch.tensor(output[19]),
torch.tensor(output[20]),
),
(
torch.tensor(output[21]),
torch.tensor(output[22]),
),
(
torch.tensor(output[23]),
torch.tensor(output[24]),
),
(
torch.tensor(output[25]),
torch.tensor(output[26]),
),
(
torch.tensor(output[27]),
torch.tensor(output[28]),
),
(
torch.tensor(output[29]),
torch.tensor(output[30]),
),
(
torch.tensor(output[31]),
torch.tensor(output[32]),
),
(
torch.tensor(output[33]),
torch.tensor(output[34]),
),
(
torch.tensor(output[35]),
torch.tensor(output[36]),
),
(
torch.tensor(output[37]),
torch.tensor(output[38]),
),
(
torch.tensor(output[39]),
torch.tensor(output[40]),
),
(
torch.tensor(output[41]),
torch.tensor(output[42]),
),
(
torch.tensor(output[43]),
torch.tensor(output[44]),
),
(
torch.tensor(output[45]),
torch.tensor(output[46]),
),
(
torch.tensor(output[47]),
torch.tensor(output[48]),
),
(
torch.tensor(output[49]),
torch.tensor(output[50]),
),
(
torch.tensor(output[51]),
torch.tensor(output[52]),
),
(
torch.tensor(output[53]),
torch.tensor(output[54]),
),
(
torch.tensor(output[55]),
torch.tensor(output[56]),
),
(
torch.tensor(output[57]),
torch.tensor(output[58]),
),
(
torch.tensor(output[59]),
torch.tensor(output[60]),
),
(
torch.tensor(output[61]),
torch.tensor(output[62]),
),
(
torch.tensor(output[63]),
torch.tensor(output[64]),
),
(
torch.tensor(output[65]),
torch.tensor(output[66]),
),
(
torch.tensor(output[67]),
torch.tensor(output[68]),
),
(
torch.tensor(output[69]),
torch.tensor(output[70]),
),
(
torch.tensor(output[71]),
torch.tensor(output[72]),
),
(
torch.tensor(output[73]),
torch.tensor(output[74]),
),
(
torch.tensor(output[75]),
torch.tensor(output[76]),
),
(
torch.tensor(output[77]),
torch.tensor(output[78]),
),
(
torch.tensor(output[79]),
torch.tensor(output[80]),
),
)
return result

View File

@@ -6,10 +6,10 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
CompiledLNFEmbeddingLayer,
LMHeadEmbeddingLayer,
CompiledLMHeadEmbeddingLayer,
DecoderLayer,
EightDecoderLayer,
CompiledDecoderLayer,
CompiledEightDecoderLayer,
FourWayShardingDecoderLayer,
TwoWayShardingDecoderLayer,
CompiledFourWayShardingDecoderLayer,
CompiledTwoWayShardingDecoderLayer,
ShardedFalconModel,
)
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
@@ -94,6 +94,13 @@ parser.add_argument(
action=argparse.BooleanOptionalAction,
help="Run model as sharded",
)
parser.add_argument(
"--num_shards",
type=int,
default=4,
choices=[2, 4],
help="Number of shards.",
)
class ShardedFalcon(SharkLLMBase):
@@ -122,6 +129,10 @@ class ShardedFalcon(SharkLLMBase):
--hf_auth_token flag. You can ask for the access to the model
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
)
if args.sharded and "180b" not in self.model_name:
raise ValueError("Sharding supported only for Falcon-180B")
self.hf_auth_token = hf_auth_token
self.max_padding_length = 100
self.device = device
@@ -131,7 +142,7 @@ class ShardedFalcon(SharkLLMBase):
self.debug = debug
self.tokenizer = self.get_tokenizer()
self.src_model = self.get_src_model()
self.shark_model = self.compile(compressed=args.compressed)
self.shark_model = self.compile()
def get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
@@ -146,20 +157,17 @@ class ShardedFalcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"torch_dtype": torch.float32,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile_layer(
@@ -225,7 +233,7 @@ class ShardedFalcon(SharkLLMBase):
elif layer_id in ["ln_f", "lm_head"]:
f16_input_mask = [True]
elif "_" in layer_id or type(layer_id) == int:
f16_input_mask = [True, False]
f16_input_mask = [True, True]
else:
raise ValueError("Unsupported layer: ", layer_id)
@@ -288,28 +296,16 @@ class ShardedFalcon(SharkLLMBase):
return shark_module, device_idx
def compile(self, compressed=False):
def compile(self):
sample_input_ids = torch.zeros([100], dtype=torch.int64)
sample_attention_mask = torch.zeros(
[1, 1, 100, 100], dtype=torch.float32
)
num_group_layers = 1
if "7b" in self.model_name:
num_in_features = 4544
if compressed:
num_group_layers = 8
elif "40b" in self.model_name:
num_in_features = 8192
if compressed:
num_group_layers = 15
else:
num_in_features = 14848
sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
if compressed:
num_group_layers = 20
num_group_layers = int(
20 * (4 / args.num_shards)
) # 4 is the number of default shards
sample_hidden_states = torch.zeros(
[1, 100, num_in_features], dtype=torch.float32
[1, 100, 14848], dtype=torch.float32
)
# Determine number of available devices
@@ -319,6 +315,10 @@ class ShardedFalcon(SharkLLMBase):
haldriver = ireert.get_driver(self.device)
num_devices = len(haldriver.query_available_devices())
if num_devices < 2:
raise ValueError(
"Cannot run Falcon-180B on a single ROCM device."
)
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
print("Compiling Layer lm_head")
@@ -326,7 +326,9 @@ class ShardedFalcon(SharkLLMBase):
lm_head,
[sample_hidden_states],
"lm_head",
device_idx=0 % num_devices if self.device == "rocm" else None,
device_idx=(0 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
@@ -338,7 +340,9 @@ class ShardedFalcon(SharkLLMBase):
word_embedding,
[sample_input_ids],
"word_embeddings",
device_idx=1 % num_devices if self.device == "rocm" else None,
device_idx=(1 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_word_embedding = CompiledWordEmbeddingsLayer(
shark_word_embedding
@@ -350,7 +354,9 @@ class ShardedFalcon(SharkLLMBase):
ln_f,
[sample_hidden_states],
"ln_f",
device_idx=2 % num_devices if self.device == "rocm" else None,
device_idx=(2 % num_devices) % args.num_shards
if self.device == "rocm"
else None,
)
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
@@ -360,24 +366,21 @@ class ShardedFalcon(SharkLLMBase):
):
device_idx = i % num_devices if self.device == "rocm" else None
layer_id = i
pytorch_class = DecoderLayer
compiled_class = CompiledDecoderLayer
if compressed:
layer_id = (
str(i * num_group_layers)
+ "_"
+ str((i + 1) * num_group_layers)
)
pytorch_class = EightDecoderLayer
compiled_class = CompiledEightDecoderLayer
layer_id = (
str(i * num_group_layers)
+ "_"
+ str((i + 1) * num_group_layers)
)
pytorch_class = FourWayShardingDecoderLayer
compiled_class = CompiledFourWayShardingDecoderLayer
if args.num_shards == 2:
pytorch_class = TwoWayShardingDecoderLayer
compiled_class = CompiledTwoWayShardingDecoderLayer
print("Compiling Layer {}".format(layer_id))
if compressed:
layer_i = self.src_model.transformer.h[
i * num_group_layers : (i + 1) * num_group_layers
]
else:
layer_i = self.src_model.transformer.h[i]
layer_i = self.src_model.transformer.h[
i * num_group_layers : (i + 1) * num_group_layers
]
pytorch_layer_i = pytorch_class(
layer_i, args.falcon_variant_to_use
@@ -388,13 +391,13 @@ class ShardedFalcon(SharkLLMBase):
layer_id,
device_idx=device_idx,
)
del shark_module
shark_layer_i = compiled_class(
layer_id,
device_idx,
args.falcon_variant_to_use,
self.device,
self.precision,
shark_module,
)
shark_layers.append(shark_layer_i)
@@ -668,20 +671,17 @@ class UnshardedFalcon(SharkLLMBase):
def get_src_model(self):
print("Loading src model: ", self.model_name)
kwargs = {
"torch_dtype": torch.float,
"torch_dtype": torch.float32,
"trust_remote_code": True,
"token": self.hf_auth_token,
}
if self.precision == "int4":
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["load_gptq_on_cpu"] = True
kwargs["device_map"] = "cpu"
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
)
if self.precision == "int4":
falcon_model = falcon_model.to(torch.float32)
return falcon_model
def compile(self):

View File

@@ -105,6 +105,7 @@ def main():
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
control_mode=args.control_mode,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -0,0 +1,96 @@
import torch
import time
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
clear_all,
save_output_img,
)
def main():
if args.clear_all:
clear_all()
# TODO: prompt_embeds and text_embeds form base_model.json requires fixing
args.precision = "fp16"
args.height = 1024
args.width = 1024
args.max_length = 77
args.scheduler = "DDIM"
print(
"Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM"
)
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
height=args.height,
width=args.width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=args.ondemand,
)
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
for current_batch in range(args.batch_count):
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += (
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
)
text_output += (
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
)
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)
if __name__ == "__main__":
main()

View File

@@ -9,6 +9,7 @@ from apps.stable_diffusion.src.utils import (
)
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
Text2ImageSDXLPipeline,
Image2ImagePipeline,
InpaintPipeline,
OutpaintPipeline,

View File

@@ -1,5 +1,5 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from transformers import CLIPTextModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from collections import defaultdict
from pathlib import Path
import torch
@@ -24,6 +24,8 @@ from apps.stable_diffusion.src.utils import (
get_stencil_model_id,
update_lora_weight,
)
from shark.shark_downloader import download_public_file
from shark.shark_inference import SharkInference
# These shapes are parameter dependent.
@@ -55,6 +57,10 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
new_shape.append(math.ceil(height / div_val))
elif "width" in shape[i]:
new_shape.append(math.ceil(width / div_val))
elif "+" in shape[i]:
# Currently this case only hits for SDXL. So, in case any other
# case requires this operator, change this.
new_shape.append(height + width)
else:
new_shape.append(shape[i])
return new_shape
@@ -67,6 +73,70 @@ def check_compilation(model, model_name):
)
def shark_compile_after_ir(
module_name,
device,
vmfb_path,
precision,
ir_path=None,
):
if ir_path:
print(f"[DEBUG] mlir found at {ir_path.absolute()}")
module = SharkInference(
mlir_module=ir_path,
device=device,
mlir_dialect="tm_tensor",
)
print(f"Will get extra flag for {module_name} and precision = {precision}")
path = module.save_module(
vmfb_path.parent.absolute(),
vmfb_path.stem,
extra_args=get_opt_flags(module_name, precision=precision),
)
print(f"Saved {module_name} vmfb at {path}")
module.load_module(path)
return module
def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
name_split = extended_model_name.split("_")
if "vae" in model_name:
name_split[5] = "fp32"
extended_model_name_for_vmfb = "_".join(name_split)
extended_model_name_for_mlir = "_".join(name_split[:-1])
vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb")
if "vulkan" in device:
_device = args.iree_vulkan_target_triple
_device = _device.replace("-", "_")
vmfb_path = Path(extended_model_name_for_vmfb + f"_{_device}.vmfb")
if vmfb_path.exists():
shark_module = SharkInference(
None,
device=device,
mlir_dialect="tm_tensor",
)
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=[])
return shark_module, None
mlir_path = Path(extended_model_name_for_mlir + ".mlir")
if not mlir_path.exists():
print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}")
download_public_file(
f"gs://shark_tank/SDXL/mlir/{mlir_path.name}",
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
return (
shark_compile_after_ir(
model_name, device, vmfb_path, precision, mlir_path
),
None,
)
return None, None
class SharkifyStableDiffusionModel:
def __init__(
self,
@@ -86,13 +156,15 @@ class SharkifyStableDiffusionModel:
generate_vmfb: bool = True,
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
is_sdxl: bool = False,
stencils: list[str] = [],
use_lora: str = "",
use_quantize: str = None,
return_mlir: bool = False,
):
self.check_params(max_len, width, height)
self.max_len = max_len
self.is_sdxl = is_sdxl
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
@@ -144,7 +216,7 @@ class SharkifyStableDiffusionModel:
self.low_cpu_mem_usage = low_cpu_mem_usage
self.is_inpaint = is_inpaint
self.is_upscaler = is_upscaler
self.use_stencil = get_stencil_model_id(use_stencil)
self.stencils = [get_stencil_model_id(x) for x in stencils]
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.use_lora = use_lora
@@ -175,6 +247,7 @@ class SharkifyStableDiffusionModel:
model_name = {}
sub_model_list = [
"clip",
"clip2",
"unet",
"unet512",
"stencil_unet",
@@ -195,8 +268,9 @@ class SharkifyStableDiffusionModel:
)
if self.base_vae:
sub_model = "base_vae"
if "stencil_adaptor" == model and self.use_stencil is not None:
model_config = model_config + get_path_stem(self.use_stencil)
# TODO: Fix this
# if "stencil_adaptor" == model and self.use_stencil is not None:
# model_config = model_config + get_path_stem(self.use_stencil)
model_name[model] = get_extended_name(sub_model + model_config)
index += 1
return model_name
@@ -342,6 +416,76 @@ class SharkifyStableDiffusionModel:
)
return shark_vae, vae_mlir
def get_vae_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["vae"], "vae", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class VaeModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
base_vae=self.base_vae,
custom_vae=self.custom_vae,
low_cpu_mem_usage=False,
):
super().__init__()
self.vae = None
if custom_vae == "":
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif not isinstance(custom_vae, dict):
self.vae = AutoencoderKL.from_pretrained(
custom_vae,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.vae = AutoencoderKL.from_pretrained(
model_id,
subfolder="vae",
low_cpu_mem_usage=low_cpu_mem_usage,
)
self.vae.load_state_dict(custom_vae)
def forward(self, latents):
image = self.vae.decode(latents / 0.13025, return_dict=False)[
0
]
return image
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
inputs = tuple(self.inputs["vae"])
# Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
# pipeline.
is_f16 = False
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
if self.debug:
os.makedirs(save_dir, exist_ok=True)
shark_vae, vae_mlir = compile_through_fx(
vae,
inputs,
is_f16=is_f16,
use_tuned=self.use_tuned,
extended_model_name=self.model_name["vae"],
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("vae", precision=self.precision),
base_model_id=self.base_model_id,
model_name="vae",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_vae, vae_mlir
def get_controlled_unet(self, use_large=False):
class ControlledUnetModel(torch.nn.Module):
def __init__(
@@ -380,25 +524,54 @@ class SharkifyStableDiffusionModel:
control11,
control12,
control13,
scale1,
scale2,
scale3,
scale4,
scale5,
scale6,
scale7,
scale8,
scale9,
scale10,
scale11,
scale12,
scale13,
):
# TODO: Average pooling
db_res_samples = [
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
]
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
db_res_samples = tuple(
[
control1,
control2,
control3,
control4,
control5,
control6,
control7,
control8,
control9,
control10,
control11,
control12,
control1 * scale1,
control2 * scale2,
control3 * scale3,
control4 * scale4,
control5 * scale5,
control6 * scale6,
control7 * scale7,
control8 * scale8,
control9 * scale9,
control10 * scale10,
control11 * scale11,
control12 * scale12,
]
)
mb_res_samples = control13
mb_res_samples = control13 * scale13
latents = torch.cat([latent] * 2)
unet_out = self.unet.forward(
latents,
@@ -446,6 +619,19 @@ class SharkifyStableDiffusionModel:
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
unet,
@@ -462,11 +648,11 @@ class SharkifyStableDiffusionModel:
)
return shark_controlled_unet, controlled_unet_mlir
def get_control_net(self, use_large=False):
def get_control_net(self, stencil_id, use_large=False):
stencil_id = get_stencil_model_id(stencil_id)
class StencilControlNetModel(torch.nn.Module):
def __init__(
self, model_id=self.use_stencil, low_cpu_mem_usage=False
):
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
super().__init__()
self.cnet = ControlNetModel.from_pretrained(
model_id,
@@ -481,6 +667,19 @@ class SharkifyStableDiffusionModel:
timestep,
text_embedding,
stencil_image_input,
acc1,
acc2,
acc3,
acc4,
acc5,
acc6,
acc7,
acc8,
acc9,
acc10,
acc11,
acc12,
acc13,
):
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
# TODO: guidance NOT NEEDED change in `get_input_info` later
@@ -502,6 +701,20 @@ class SharkifyStableDiffusionModel:
)
return tuple(
list(down_block_res_samples) + [mid_block_res_sample]
) + (
acc1 + down_block_res_samples[0],
acc2 + down_block_res_samples[1],
acc3 + down_block_res_samples[2],
acc4 + down_block_res_samples[3],
acc5 + down_block_res_samples[4],
acc6 + down_block_res_samples[5],
acc7 + down_block_res_samples[6],
acc8 + down_block_res_samples[7],
acc9 + down_block_res_samples[8],
acc10 + down_block_res_samples[9],
acc11 + down_block_res_samples[10],
acc12 + down_block_res_samples[11],
acc13 + mid_block_res_sample,
)
scnet = StencilControlNetModel(
@@ -517,7 +730,7 @@ class SharkifyStableDiffusionModel:
inputs[0],
inputs[1],
torch.nn.functional.pad(inputs[2], pad),
inputs[3],
*inputs[3:],
)
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
@@ -526,7 +739,7 @@ class SharkifyStableDiffusionModel:
save_dir = os.path.join(
self.sharktank_dir, self.model_name["stencil_adaptor"]
)
input_mask = [True, True, True, True]
input_mask = [True, True, True, True] + ([True] * 13)
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
shark_cnet, cnet_mlir = compile_through_fx(
scnet,
@@ -688,6 +901,93 @@ class SharkifyStableDiffusionModel:
)
return shark_unet, unet_mlir
def get_unet_sdxl(self):
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
self.model_name["unet"], "unet", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class UnetModel(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
model_id,
subfolder="unet",
low_cpu_mem_usage=low_cpu_mem_usage,
)
if (
args.attention_slicing is not None
and args.attention_slicing != "none"
):
if args.attention_slicing.isdigit():
self.unet.set_attention_slice(
int(args.attention_slicing)
)
else:
self.unet.set_attention_slice(args.attention_slicing)
def forward(
self,
latent,
timestep,
prompt_embeds,
text_embeds,
time_ids,
guidance_scale,
):
added_cond_kwargs = {
"text_embeds": text_embeds,
"time_ids": time_ids,
}
noise_pred = self.unet.forward(
latent,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=None,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
return noise_pred
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
is_f16 = True if self.precision == "fp16" else False
inputs = tuple(self.inputs["unet"])
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
input_mask = [True, True, True, True, True, True]
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_unet, unet_mlir
def get_clip(self):
class CLIPText(torch.nn.Module):
def __init__(
@@ -735,6 +1035,78 @@ class SharkifyStableDiffusionModel:
)
return shark_clip, clip_mlir
def get_clip_sdxl(self, clip_index=1):
if clip_index == 1:
extended_model_name = self.model_name["clip"]
model_name = "clip"
else:
extended_model_name = self.model_name["clip2"]
model_name = "clip2"
# TODO: Remove this after convergence with shark_tank. This should just be part of
# opt_params.py.
shark_module_or_none = process_vmfb_ir_sdxl(
extended_model_name, f"clip", args.device, self.precision
)
if shark_module_or_none[0]:
return shark_module_or_none
class CLIPText(torch.nn.Module):
def __init__(
self,
model_id=self.model_id,
low_cpu_mem_usage=False,
clip_index=1,
):
super().__init__()
if clip_index == 1:
self.text_encoder = CLIPTextModel.from_pretrained(
model_id,
subfolder="text_encoder",
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
self.text_encoder = (
CLIPTextModelWithProjection.from_pretrained(
model_id,
subfolder="text_encoder_2",
low_cpu_mem_usage=low_cpu_mem_usage,
)
)
def forward(self, input):
prompt_embeds = self.text_encoder(
input,
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
clip_model = CLIPText(
low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
)
save_dir = os.path.join(self.sharktank_dir, extended_model_name)
if self.debug:
os.makedirs(
save_dir,
exist_ok=True,
)
shark_clip, clip_mlir = compile_through_fx(
clip_model,
tuple(self.inputs["clip"]),
extended_model_name=extended_model_name,
debug=self.debug,
generate_vmfb=self.generate_vmfb,
save_dir=save_dir,
extra_args=get_opt_flags("clip", precision="fp32"),
base_model_id=self.base_model_id,
model_name="clip",
precision=self.precision,
return_mlir=self.return_mlir,
)
return shark_clip, clip_mlir
def process_custom_vae(self):
custom_vae = self.custom_vae.lower()
if not custom_vae.endswith((".ckpt", ".safetensors")):
@@ -767,7 +1139,9 @@ class SharkifyStableDiffusionModel:
}
return vae_dict
def compile_unet_variants(self, model, use_large=False):
def compile_unet_variants(self, model, use_large=False, base_model=""):
if self.is_sdxl:
return self.get_unet_sdxl()
if model == "unet":
if self.is_upscaler:
return self.get_unet_upscaler(use_large=use_large)
@@ -809,9 +1183,28 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def sdxl_clip(self):
try:
self.inputs["clip"] = self.get_input_info_for(
base_models["sdxl_clip"]
)
compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
check_compilation(compiled_clip, "Clip")
check_compilation(compiled_clip, "Clip2")
if self.return_mlir:
return clip_mlir, clip_mlir2
return compiled_clip, compiled_clip2
except Exception as e:
sys.exit(e)
def unet(self, use_large=False):
try:
model = "stencil_unet" if self.use_stencil is not None else "unet"
stencil_count = 0
for stencil in self.stencils:
stencil_count += 1
model = "stencil_unet" if stencil_count > 0 else "unet"
compiled_unet = None
unet_inputs = base_models[model]
@@ -820,7 +1213,7 @@ class SharkifyStableDiffusionModel:
unet_inputs[self.base_model_id]
)
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=self.base_model_id
)
else:
for model_id in unet_inputs:
@@ -831,7 +1224,7 @@ class SharkifyStableDiffusionModel:
try:
compiled_unet, unet_mlir = self.compile_unet_variants(
model, use_large=use_large
model, use_large=use_large, base_model=model_id
)
except Exception as e:
print(e)
@@ -870,7 +1263,10 @@ class SharkifyStableDiffusionModel:
is_base_vae = self.base_vae
if self.is_upscaler:
self.base_vae = True
compiled_vae, vae_mlir = self.get_vae()
if self.is_sdxl:
compiled_vae, vae_mlir = self.get_vae_sdxl()
else:
compiled_vae, vae_mlir = self.get_vae()
self.base_vae = is_base_vae
check_compilation(compiled_vae, "Vae")
@@ -880,13 +1276,13 @@ class SharkifyStableDiffusionModel:
except Exception as e:
sys.exit(e)
def controlnet(self, use_large=False):
def controlnet(self, stencil_id, use_large=False):
try:
self.inputs["stencil_adaptor"] = self.get_input_info_for(
base_models["stencil_adaptor"]
)
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
use_large=use_large
stencil_id, use_large=use_large
)
check_compilation(compiled_stencil_adaptor, "Stencil")

View File

@@ -123,8 +123,8 @@ def get_clip():
return get_shark_model(bucket, model_name, iree_flags)
def get_tokenizer():
def get_tokenizer(subfolder="tokenizer"):
tokenizer = CLIPTokenizer.from_pretrained(
args.hf_model_id, subfolder="tokenizer"
args.hf_model_id, subfolder=subfolder
)
return tokenizer

View File

@@ -1,6 +1,9 @@
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
Text2ImagePipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
Text2ImageSDXLPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)

View File

@@ -158,7 +158,6 @@ class Image2ImagePipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
resample_type,
):
# prompts and negative prompts must be a list.

View File

@@ -55,28 +55,47 @@ class StencilPipeline(StableDiffusionPipeline):
import_mlir: bool,
use_lora: str,
ondemand: bool,
controlnet_names: list[str],
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
self.controlnet = None
self.controlnet_512 = None
self.controlnet = [None] * len(controlnet_names)
self.controlnet_512 = [None] * len(controlnet_names)
self.controlnet_id = [str] * len(controlnet_names)
self.controlnet_512_id = [str] * len(controlnet_names)
self.controlnet_names = controlnet_names
def load_controlnet(self):
if self.controlnet is not None:
def load_controlnet(self, index, model_name):
if model_name is None:
return
self.controlnet = self.sd_model.controlnet()
def unload_controlnet(self):
del self.controlnet
self.controlnet = None
def load_controlnet_512(self):
if self.controlnet_512 is not None:
if (
self.controlnet[index] is not None
and self.controlnet_id[index] is not None
and self.controlnet_id[index] == model_name
):
return
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
self.controlnet_id[index] = model_name
self.controlnet[index] = self.sd_model.controlnet(model_name)
def unload_controlnet_512(self):
del self.controlnet_512
self.controlnet_512 = None
def unload_controlnet(self, index):
del self.controlnet[index]
self.controlnet_id[index] = None
self.controlnet[index] = None
def load_controlnet_512(self, index, model_name):
if (
self.controlnet_512[index] is not None
and self.controlnet_512_id[index] == model_name
):
return
self.controlnet_512_id[index] = model_name
self.controlnet_512[index] = self.sd_model.controlnet(
model_name, use_large=True
)
def unload_controlnet_512(self, index):
del self.controlnet_512[index]
self.controlnet_512_id[index] = None
self.controlnet_512[index] = None
def prepare_latents(
self,
@@ -111,8 +130,9 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps,
dtype,
cpu_scheduling,
controlnet_hint=None,
stencil_hints=[None],
controlnet_conditioning_scale: float = 1.0,
control_mode="Balanced", # Prompt, Balanced, or Controlnet
mask=None,
masked_image_latents=None,
return_all_latents=False,
@@ -121,12 +141,18 @@ class StencilPipeline(StableDiffusionPipeline):
latent_history = [latents]
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
assert control_mode in ["Prompt", "Balanced", "Controlnet"]
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
self.load_controlnet()
else:
self.load_unet_512()
self.load_controlnet_512()
for i, name in enumerate(self.controlnet_names):
if text_embeddings.shape[1] <= self.model_max_length:
self.load_controlnet(i, name)
else:
self.load_controlnet_512(i, name)
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype)
@@ -149,33 +175,93 @@ class StencilPipeline(StableDiffusionPipeline):
).to(dtype)
else:
latent_model_input_1 = latent_model_input
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
else:
control = self.controlnet_512(
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
),
send_to_host=False,
)
# Multicontrolnet
width = latent_model_input_1.shape[2]
height = latent_model_input_1.shape[3]
dtype = latent_model_input_1.dtype
control_acc = (
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3
+ [
torch.zeros(
(2, 320, int(height / 2), int(width / 2)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 640, int(height / 2), int(width / 2)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 640, int(height / 4), int(width / 4)), dtype=dtype
)
]
+ [
torch.zeros(
(2, 1280, int(height / 4), int(width / 4)), dtype=dtype
)
]
* 2
+ [
torch.zeros(
(2, 1280, int(height / 8), int(width / 8)), dtype=dtype
)
]
* 4
)
for i, controlnet_hint in enumerate(stencil_hints):
if controlnet_hint is None:
continue
if text_embeddings.shape[1] <= self.model_max_length:
control = self.controlnet[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
else:
control = self.controlnet_512[i](
"forward",
(
latent_model_input_1,
timestep,
text_embeddings,
controlnet_hint,
*control_acc,
),
send_to_host=False,
)
control_acc = control[13:]
control = control[:13]
timestep = timestep.detach().numpy()
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
dtype = latents.dtype
if control_mode == "Balanced":
control_scale = [
torch.tensor(1.0, dtype=dtype) for _ in range(len(control))
]
elif control_mode == "Prompt":
control_scale = [
torch.tensor(0.825**x, dtype=dtype)
for x in range(len(control))
]
elif control_mode == "Controlnet":
control_scale = [
torch.tensor(float(guidance_scale), dtype=dtype)
for _ in range(len(control))
]
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
@@ -197,6 +283,19 @@ class StencilPipeline(StableDiffusionPipeline):
control[10],
control[11],
control[12],
control_scale[0],
control_scale[1],
control_scale[2],
control_scale[3],
control_scale[4],
control_scale[5],
control_scale[6],
control_scale[7],
control_scale[8],
control_scale[9],
control_scale[10],
control_scale[11],
control_scale[12],
),
send_to_host=False,
)
@@ -222,6 +321,19 @@ class StencilPipeline(StableDiffusionPipeline):
control[10],
control[11],
control[12],
control_scale[0],
control_scale[1],
control_scale[2],
control_scale[3],
control_scale[4],
control_scale[5],
control_scale[6],
control_scale[7],
control_scale[8],
control_scale[9],
control_scale[10],
control_scale[11],
control_scale[12],
),
send_to_host=False,
)
@@ -245,8 +357,9 @@ class StencilPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
self.unload_controlnet()
self.unload_controlnet_512()
for i in range(len(self.controlnet_names)):
self.unload_controlnet(i)
self.unload_controlnet_512(i)
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -272,14 +385,30 @@ class StencilPipeline(StableDiffusionPipeline):
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
stencils,
stencil_images,
resample_type,
control_mode,
):
# Control Embedding check & conversion
# TODO: 1. Change `num_images_per_prompt`.
controlnet_hint = controlnet_hint_conversion(
image, use_stencil, height, width, dtype, num_images_per_prompt=1
)
# controlnet_hint = controlnet_hint_conversion(
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
# )
stencil_hints = []
for i, stencil in enumerate(stencils):
image = stencil_images[i]
stencil_hints.append(
controlnet_hint_conversion(
image,
stencil,
height,
width,
dtype,
num_images_per_prompt=1,
)
)
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
@@ -327,7 +456,8 @@ class StencilPipeline(StableDiffusionPipeline):
total_timesteps=final_timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
controlnet_hint=controlnet_hint,
control_mode=control_mode,
stencil_hints=stencil_hints,
)
# Img latents -> PIL images

View File

@@ -0,0 +1,214 @@
import torch
import numpy as np
from random import randint
from typing import Union
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
def __init__(
self,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
KDPM2DiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
DEISMultistepScheduler,
DDPMScheduler,
DPMSolverSinglestepScheduler,
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
],
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype
):
add_time_ids = list(
original_size + crops_coords_top_left + target_size
)
# self.unet.config.addition_time_embed_dim IS 256.
# self.text_encoder_2.config.projection_dim IS 1280.
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
expected_add_embed_dim = 2816
# self.unet.add_embedding.linear_1.in_features IS 2816.
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
def generate_images(
self,
prompts,
neg_prompts,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get initial latents.
init_latents = self.prepare_latents(
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get text embeddings.
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt_sdxl(
prompt=prompts,
num_images_per_prompt=1,
do_classifier_free_guidance=True,
negative_prompt=neg_prompts,
)
# Prepare timesteps.
self.scheduler.set_timesteps(num_inference_steps)
timesteps = self.scheduler.timesteps
# Prepare added time ids & embeddings.
original_size = (height, width)
target_size = (height, width)
crops_coords_top_left = (0, 0)
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
)
prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds], dim=0
)
add_text_embeds = torch.cat(
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds
add_text_embeds = add_text_embeds.to(dtype)
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(dtype)
prompt_embeds = prompt_embeds.to(dtype)
add_time_ids = add_time_ids.to(dtype)
# Get Image latents.
latents = self.produce_img_latents_sdxl(
init_latents,
timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
)
# Img latents -> PIL images.
all_imgs = []
self.load_vae()
# imgs = self.decode_latents_sdxl(None)
# all_imgs.extend(imgs)
for i in range(0, latents.shape[0], batch_size):
imgs = self.decode_latents_sdxl(latents[i : i + batch_size])
all_imgs.extend(imgs)
if self.ondemand:
self.unload_vae()
return all_imgs

View File

@@ -33,6 +33,8 @@ from apps.stable_diffusion.src.utils import (
end_profiling,
)
import sys
import gc
from typing import List, Optional
SD_STATE_IDLE = "idle"
SD_STATE_CANCEL = "cancel"
@@ -63,6 +65,7 @@ class StableDiffusionPipeline:
):
self.vae = None
self.text_encoder = None
self.text_encoder_2 = None
self.unet = None
self.unet_512 = None
self.model_max_length = 77
@@ -106,6 +109,34 @@ class StableDiffusionPipeline:
del self.text_encoder
self.text_encoder = None
def load_clip_sdxl(self):
if self.text_encoder and self.text_encoder_2:
return
if self.import_mlir or self.use_lora:
if not self.import_mlir:
print(
"Warning: LoRA provided but import_mlir not specified. "
"Importing MLIR anyways."
)
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
else:
try:
# TODO: Fix this for SDXL
self.text_encoder = get_clip()
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
(
self.text_encoder,
self.text_encoder_2,
) = self.sd_model.sdxl_clip()
def unload_clip_sdxl(self):
del self.text_encoder, self.text_encoder_2
self.text_encoder = None
self.text_encoder_2 = None
def load_unet(self):
if self.unet is not None:
return
@@ -159,6 +190,179 @@ class StableDiffusionPipeline:
def unload_vae(self):
del self.vae
self.vae = None
gc.collect()
def encode_prompt_sdxl(
self,
prompt: str,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
self.tokenizer_2 = get_tokenizer("tokenizer_2")
self.load_clip_sdxl()
tokenizers = (
[self.tokenizer, self.tokenizer_2]
if self.tokenizer is not None
else [self.tokenizer_2]
)
text_encoders = (
[self.text_encoder, self.text_encoder_2]
if self.text_encoder is not None
else [self.text_encoder_2]
)
# textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt]
for prompt, tokenizer, text_encoder in zip(
prompts, tokenizers, text_encoders
):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
)
print(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
text_encoder_output = text_encoder("forward", (text_input_ids,))
prompt_embeds = torch.from_numpy(text_encoder_output[0])
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = (
negative_prompt is None
and self.config.force_zeros_for_empty_prompt
)
if (
do_classifier_free_guidance
and negative_prompt_embeds is None
and zero_out_negative_prompt
):
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(
pooled_prompt_embeds
)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(
negative_prompt
):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(
uncond_tokens, tokenizers, text_encoders
):
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_encoder_output = text_encoder(
"forward", (uncond_input.input_ids,)
)
negative_prompt_embeds = torch.from_numpy(
text_encoder_output[0]
)
negative_pooled_prompt_embeds = torch.from_numpy(
text_encoder_output[1]
)
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(
negative_prompt_embeds_list, dim=-1
)
if self.ondemand:
self.unload_clip_sdxl()
gc.collect()
# TODO: Look into dtype for text_encoder_2!
prompt_embeds = prompt_embeds.to(dtype=torch.float32)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_images_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
1, num_images_per_prompt
).view(bs_embed * num_images_per_prompt, -1)
return (
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
def encode_prompts(self, prompts, neg_prompts, max_length):
# Tokenize text and get embeddings
@@ -186,6 +390,7 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings
@@ -298,6 +503,8 @@ class StableDiffusionPipeline:
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -306,6 +513,72 @@ class StableDiffusionPipeline:
all_latents = torch.cat(latent_history, dim=0)
return all_latents
def produce_img_latents_sdxl(
self,
latents,
total_timesteps,
add_text_embeds,
add_time_ids,
prompt_embeds,
cpu_scheduling,
guidance_scale,
dtype,
):
# return None
self.status = SD_STATE_IDLE
step_time_sum = 0
extra_step_kwargs = {"generator": None}
self.load_unet()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
timestep = torch.tensor([t]).to(dtype).detach().numpy()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
).to(dtype)
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
prompt_embeds,
add_text_embeds,
add_time_ids,
guidance_scale,
),
send_to_host=False,
)
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
step_time = (time.time() - step_start_time) * 1000
step_time_sum += step_time
if self.status == SD_STATE_CANCEL:
break
if self.ondemand:
self.unload_unet()
gc.collect()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
return latents
def decode_latents_sdxl(self, latents):
latents = latents.to(torch.float32)
images = self.vae("forward", (latents,))
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
return pil_images
@classmethod
def from_pretrained(
cls,
@@ -338,7 +611,8 @@ class StableDiffusionPipeline:
ondemand: bool,
low_cpu_mem_usage: bool = False,
debug: bool = False,
use_stencil: str = None,
stencils: list[str] = [],
# stencil_images: list[Image] = []
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
@@ -355,6 +629,7 @@ class StableDiffusionPipeline:
"OutpaintPipeline",
]
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
sd_model = SharkifyStableDiffusionModel(
model_id,
@@ -371,7 +646,8 @@ class StableDiffusionPipeline:
debug=debug,
is_inpaint=is_inpaint,
is_upscaler=is_upscaler,
use_stencil=use_stencil,
is_sdxl=is_sdxl,
stencils=stencils,
use_lora=use_lora,
use_quantize=use_quantize,
)
@@ -386,6 +662,10 @@ class StableDiffusionPipeline:
ondemand,
)
if cls.__name__ == "StencilPipeline":
return cls(
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
# #####################################################
@@ -498,6 +778,7 @@ class StableDiffusionPipeline:
clip_inf_time = (time.time() - clip_inf_start) * 1000
if self.ondemand:
self.unload_clip()
gc.collect()
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
return text_embeddings.numpy()

View File

@@ -8,6 +8,15 @@
"dtype":"i64"
}
},
"sdxl_clip": {
"token" : {
"shape" : [
"1*batch_size",
"max_len"
],
"dtype":"i64"
}
},
"vae_encode": {
"image" : {
"shape" : [
@@ -179,6 +188,49 @@
"shape": [2],
"dtype": "i64"
}
},
"stabilityai/stable-diffusion-xl-base-1.0": {
"latents": {
"shape": [
"2*batch_size",
4,
"height",
"width"
],
"dtype": "f32"
},
"timesteps": {
"shape": [
1
],
"dtype": "f32"
},
"prompt_embeds": {
"shape": [
"2*batch_size",
"max_len",
2048
],
"dtype": "f32"
},
"text_embeds": {
"shape": [
"2*batch_size",
1280
],
"dtype": "f32"
},
"time_ids": {
"shape": [
"2*batch_size",
6
],
"dtype": "f32"
},
"guidance_scale": {
"shape": 1,
"dtype": "f32"
}
}
},
"stencil_adaptor": {
@@ -208,6 +260,58 @@
"controlnet_hint": {
"shape": [1, 3, "8*height", "8*width"],
"dtype": "f32"
},
"acc1": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc2": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc3": {
"shape": [2, 320, "height", "width"],
"dtype": "f32"
},
"acc4": {
"shape": [2, 320, "height/2", "width/2"],
"dtype": "f32"
},
"acc5": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc6": {
"shape": [2, 640, "height/2", "width/2"],
"dtype": "f32"
},
"acc7": {
"shape": [2, 640, "height/4", "width/4"],
"dtype": "f32"
},
"acc8": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc9": {
"shape": [2, 1280, "height/4", "width/4"],
"dtype": "f32"
},
"acc10": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc11": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc12": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"acc13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
}
},
"stencil_unet": {
@@ -290,6 +394,58 @@
"control13": {
"shape": [2, 1280, "height/8", "width/8"],
"dtype": "f32"
},
"scale1": {
"shape": 1,
"dtype": "f32"
},
"scale2": {
"shape": 1,
"dtype": "f32"
},
"scale3": {
"shape": 1,
"dtype": "f32"
},
"scale4": {
"shape": 1,
"dtype": "f32"
},
"scale5": {
"shape": 1,
"dtype": "f32"
},
"scale6": {
"shape": 1,
"dtype": "f32"
},
"scale7": {
"shape": 1,
"dtype": "f32"
},
"scale8": {
"shape": 1,
"dtype": "f32"
},
"scale9": {
"shape": 1,
"dtype": "f32"
},
"scale10": {
"shape": 1,
"dtype": "f32"
},
"scale11": {
"shape": 1,
"dtype": "f32"
},
"scale12": {
"shape": 1,
"dtype": "f32"
},
"scale13": {
"shape": 1,
"dtype": "f32"
}
}
}

View File

@@ -59,24 +59,28 @@
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
"--iree-opt-data-tiling=False"
]
}
}

View File

@@ -85,7 +85,7 @@ p.add_argument(
"--height",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The height of the output image.",
)
@@ -93,7 +93,7 @@ p.add_argument(
"--width",
type=int,
default=512,
choices=range(128, 769, 8),
choices=range(128, 1025, 8),
help="The width of the output image.",
)
@@ -420,6 +420,13 @@ p.add_argument(
help="Enable the stencil feature.",
)
p.add_argument(
"--control_mode",
choices=["Prompt", "Balanced", "Controlnet"],
default="Balanced",
help="How Controlnet injection should be prioritized.",
)
p.add_argument(
"--use_lora",
type=str,
@@ -587,6 +594,13 @@ p.add_argument(
help="Controls constant folding in iree-compile for all SD models.",
)
p.add_argument(
"--data_tiling",
default=False,
action=argparse.BooleanOptionalAction,
help="Controls data tiling in iree-compile for all SD models.",
)
##############################################################################
# Web UI flags
##############################################################################

View File

@@ -1,6 +1,10 @@
import numpy as np
from PIL import Image
import torch
import os
from pathlib import Path
import torchvision
import time
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
@@ -10,6 +14,33 @@ from apps.stable_diffusion.src.utils.stencils import (
stencil = {}
def save_img(img):
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
subdir = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
os.makedirs(subdir, exist_ok=True)
if isinstance(img, Image.Image):
img.save(
os.path.join(
subdir, "controlnet_" + str(int(time.time())) + ".png"
)
)
elif isinstance(img, np.ndarray):
img = Image.fromarray(img)
img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
else:
converter = torchvision.transforms.ToPILImage()
for i in img:
converter(i).save(
os.path.join(subdir, str(int(time.time())) + ".png")
)
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
@@ -161,6 +192,7 @@ def hint_canny(
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -176,6 +208,7 @@ def hint_openpose(
stencil["openpose"] = OpenposeDetector()
detected_map, _ = stencil["openpose"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -187,6 +220,7 @@ def hint_scribble(image: Image.Image):
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
save_img(detected_map)
return detected_map
@@ -199,5 +233,6 @@ def hint_zoedepth(image: Image.Image):
stencil["depth"] = ZoeDetector()
detected_map = stencil["depth"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map

View File

@@ -118,7 +118,7 @@ def compile_through_fx(
is_f16=False,
f16_input_mask=None,
use_tuned=False,
save_dir=tempfile.gettempdir(),
save_dir="",
debug=False,
generate_vmfb=True,
extra_args=None,
@@ -541,6 +541,8 @@ def get_opt_flags(model, precision="fp16"):
iree_flags.append(
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
)
if args.data_tiling == False:
iree_flags.append("--iree-opt-data-tiling=False")
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
iree_flags += opt_flags[model][is_tuned][precision][
@@ -563,6 +565,10 @@ def get_opt_flags(model, precision="fp16"):
iree_flags += opt_flags[model][is_tuned][precision][
"specified_compilation_flags"
][device]
if "vae" not in model:
# Due to lack of support for multi-reduce, we always collapse reduction
# dims before dispatch formation right now.
iree_flags += ["--iree-flow-collapse-reduction-dims"]
return iree_flags

View File

@@ -75,11 +75,11 @@ if __name__ == "__main__":
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
from apps.stable_diffusion.web.utils.gradio_configs import (
config_gradio_tmp_imgs_folder,
from apps.stable_diffusion.web.utils.tmp_configs import (
config_tmp,
)
config_gradio_tmp_imgs_folder()
config_tmp()
import gradio as gr
# Create custom models folders if they don't exist
@@ -109,6 +109,12 @@ if __name__ == "__main__":
txt2img_sendto_inpaint,
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
# SDXL
txt2img_sdxl_inf,
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_status,
# h2ogpt_upload,
# h2ogpt_web,
img2img_web,
@@ -253,6 +259,8 @@ if __name__ == "__main__":
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
with gr.TabItem(label="Text-to-Image-SDXL (Experimental)", id=13):
txt2img_sdxl_web.render()
actual_port = app.usable_port()
if actual_port != args.server_port:

View File

@@ -10,6 +10,13 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
txt2img_sendto_outpaint,
txt2img_sendto_upscaler,
)
from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import (
txt2img_sdxl_inf,
txt2img_sdxl_web,
txt2img_sdxl_custom_model,
txt2img_sdxl_gallery,
txt2img_sdxl_status,
)
from apps.stable_diffusion.web.ui.img2img_ui import (
img2img_inf,
img2img_web,

View File

@@ -0,0 +1,55 @@
from apps.stable_diffusion.web.ui.utils import (
HSLHue,
hsl_color,
get_lora_metadata,
)
# Answers HTML to show the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_file):
# tag frequency percentage, that gets maximum amount of the staring hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency percentage, above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
if lora_file == "None":
return ["<div><i>No LoRA selected</i></div>"]
elif not lora_file.lower().endswith(".safetensors"):
return [
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
return [
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
elif metadata is None:
return [
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
else:
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]

View File

@@ -105,6 +105,18 @@ body {
background-color: var(--background-fill-primary);
}
.generating.svelte-zlszon.svelte-zlszon {
border: none;
}
.generating {
border: none !important;
}
#chatbot {
height: 100% !important;
}
/* display in full width for desktop devices */
@media (min-width: 1536px)
{
@@ -246,10 +258,39 @@ footer {
background-color: var(--block-label-background-fill);
}
/* lora tag pills */
.lora-tags {
border: 1px solid var(--border-color-primary);
color: var(--block-info-text-color) !important;
padding: var(--block-padding);
}
.lora-tag {
display: inline-block;
height: 2em;
color: rgb(212 212 212) !important;
margin-right: 5pt;
margin-bottom: 5pt;
padding: 2pt 5pt;
border-radius: 5pt;
white-space: nowrap;
}
.lora-model {
margin-bottom: var(--spacing-lg);
color: var(--block-info-text-color) !important;
line-height: var(--line-sm);
}
/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs)
line-height: var(--line-xs);
}
.output_icon_button {

View File

@@ -5,6 +5,7 @@ import gradio as gr
import PIL
from math import ceil
from PIL import Image
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -14,6 +15,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -29,6 +31,10 @@ from apps.stable_diffusion.src.utils import (
get_generation_text_info,
resampler_list,
)
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
@@ -58,7 +64,6 @@ def img2img_inf(
precision: str,
device: str,
max_length: int,
use_stencil: str,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
@@ -66,6 +71,9 @@ def img2img_inf(
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
control_mode: str,
stencils: list,
images: list,
):
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
@@ -87,11 +95,17 @@ def img2img_inf(
args.img_path = "not none"
args.ondemand = ondemand
for i, stencil in enumerate(stencils):
if images[i] is None and stencil is not None:
return None, "A stencil must have an Image input"
if images[i] is not None:
images[i] = images[i].convert("RGB")
if image_dict is None:
return None, "An Initial Image is required"
if use_stencil == "scribble":
image = image_dict["mask"].convert("RGB")
elif isinstance(image_dict, PIL.Image.Image):
# if use_stencil == "scribble":
# image = image_dict["mask"].convert("RGB")
if isinstance(image_dict, PIL.Image.Image):
image = image_dict.convert("RGB")
else:
image = image_dict["image"].convert("RGB")
@@ -121,12 +135,14 @@ def img2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
use_stencil = None if use_stencil == "None" else use_stencil
args.use_stencil = use_stencil
if use_stencil is not None:
stencil_count = 0
for stencil in stencils:
if stencil is not None:
stencil_count += 1
if stencil_count > 0:
args.scheduler = "DDIM"
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
image, width, height = resize_stencil(image)
# image, width, height = resize_stencil(image)
elif "Shark" in args.scheduler:
print(
f"Shark schedulers are not supported. Switching to EulerDiscrete "
@@ -148,7 +164,7 @@ def img2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=use_stencil,
stencils=stencils,
ondemand=ondemand,
)
if (
@@ -175,7 +191,7 @@ def img2img_inf(
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(args.scheduler)
if use_stencil is not None:
if stencil_count > 0:
args.use_tuned = False
global_obj.set_sd_obj(
StencilPipeline.from_pretrained(
@@ -192,7 +208,7 @@ def img2img_inf(
args.use_base_vae,
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_stencil=use_stencil,
stencils=stencils,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
ondemand=args.ondemand,
@@ -249,8 +265,10 @@ def img2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
stencils,
images,
resample_type=resample_type,
control_mode=control_mode,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
@@ -270,12 +288,17 @@ def img2img_inf(
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Image-to-Image", current_batch + 1, batch_count, batch_size
)
), stencils, images
return generated_imgs, text_output, ""
return generated_imgs, text_output, "", stencils, images
with gr.Blocks(title="Image-to-Image") as img2img_web:
# Stencils
# TODO: Add more stencils here
STENCIL_COUNT = 2
stencils = gr.State([None] * STENCIL_COUNT)
images = gr.State([None] * STENCIL_COUNT)
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
@@ -346,69 +369,109 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
height=300,
)
with gr.Accordion(label="Stencil Options", open=False):
with gr.Accordion(label="Multistencil Options", open=False):
choices = ["None", "canny", "openpose", "scribble"]
def cnet_preview(
checked, model, input_image, index, stencils, images
):
if not checked:
stencils[index] = None
images[index] = None
return (None, stencils, images)
images[index] = input_image
stencils[index] = model
match model:
case "canny":
canny = CannyDetector()
result = canny(np.array(input_image), 100, 200)
return (
[Image.fromarray(result), result],
stencils,
images,
)
case "openpose":
openpose = OpenposeDetector()
result = openpose(np.array(input_image))
# TODO: This is just an empty canvas, need to draw the candidates (which are in result[1])
return (
[Image.fromarray(result[0]), result],
stencils,
images,
)
case _:
return (None, stencils, images)
with gr.Row():
use_stencil = gr.Dropdown(
elem_id="stencil_model",
label="Stencil model",
cnet_1 = gr.Checkbox(show_label=False)
cnet_1_model = gr.Dropdown(
label="Controlnet 1",
value="None",
choices=[
"None",
"canny",
"openpose",
"scribble",
"zoedepth",
choices=choices,
)
cnet_1_image = gr.Image(
source="upload",
tool=None,
type="pil",
)
cnet_1_output = gr.Gallery(
show_label=False,
object_fit="scale-down",
rows=1,
columns=1,
)
cnet_1.change(
fn=(
lambda a, b, c, s, i: cnet_preview(
a, b, c, 0, s, i
)
),
inputs=[
cnet_1,
cnet_1_model,
cnet_1_image,
stencils,
images,
],
outputs=[cnet_1_output, stencils, images],
)
def show_canvas(choice):
if choice == "scribble":
return (
gr.Slider.update(visible=True),
gr.Slider.update(visible=True),
gr.Button.update(visible=True),
)
else:
return (
gr.Slider.update(visible=False),
gr.Slider.update(visible=False),
gr.Button.update(visible=False),
)
def create_canvas(w, h):
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
cnet_2 = gr.Checkbox(show_label=False)
cnet_2_model = gr.Dropdown(
label="Controlnet 2",
value="None",
choices=choices,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
cnet_2_image = gr.Image(
source="upload",
tool=None,
type="pil",
)
create_button = gr.Button(
label="Start",
value="Open drawing canvas!",
visible=False,
)
create_button.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
outputs=[img2img_init_image],
)
use_stencil.change(
fn=show_canvas,
inputs=use_stencil,
outputs=[canvas_width, canvas_height, create_button],
cnet_2_output = gr.Gallery(
show_label=False,
object_fit="scale-down",
rows=1,
columns=1,
)
cnet_2.change(
fn=(
lambda a, b, c, s, i: cnet_preview(
a, b, c, 1, s, i
)
),
inputs=[
cnet_2,
cnet_2_model,
cnet_2_image,
stencils,
images,
],
outputs=[cnet_2_output, stencils, images],
)
control_mode = gr.Radio(
choices=["Prompt", "Balanced", "Controlnet"],
value="Balanced",
label="Control Mode",
)
with gr.Accordion(label="LoRA Options", open=False):
@@ -436,6 +499,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -610,7 +678,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
precision,
device,
max_length,
use_stencil,
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
@@ -618,8 +685,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
ondemand,
repeatable_seeds,
resample_type,
control_mode,
stencils,
images,
],
outputs=[
img2img_gallery,
std_output,
img2img_status,
stencils,
images,
],
outputs=[img2img_gallery, std_output, img2img_status],
show_progress="minimal" if args.progress_bar else "none",
)
@@ -638,3 +714,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -4,6 +4,7 @@ import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -13,6 +14,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -120,7 +122,7 @@ def inpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -319,6 +321,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -518,3 +525,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -3,9 +3,8 @@ import torch
import time
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -122,7 +121,7 @@ def outpaint_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -323,6 +322,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -546,3 +550,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -91,7 +91,7 @@ with gr.Blocks() as outputgallery_web:
value=gallery_files.value,
visible=False,
show_label=True,
columns=2,
columns=4,
)
with gr.Column(scale=4):
@@ -204,6 +204,9 @@ with gr.Blocks() as outputgallery_web:
),
]
def on_image_columns_change(columns):
return gr.Gallery.update(columns=columns)
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
@@ -365,53 +368,6 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
)
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
setTimeout(() => {{
required_style = "auto ".repeat(new_cols).trim();
gallery = document.querySelector('#outputgallery_gallery .grid-container');
if (gallery) {{
gallery.style.gridTemplateColumns = required_style
}}
}}, {timeout_length});
return []; // prevents console error from gradio
}}
"""
# --- Wire handlers up to the actions
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
@@ -423,32 +379,35 @@ with gr.Blocks() as outputgallery_web:
queue=False,
)
image_columns.change(**set_gallery_columns_immediate)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
open_subdir.click(
on_open_subdir, inputs=[subdirectories], queue=False
).then(**set_gallery_columns_immediate)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
).then(**set_gallery_columns_delayed)
)
outputgallery_filename.change(
on_outputgallery_filename_change,
@@ -477,7 +436,7 @@ with gr.Blocks() as outputgallery_web:
open_subdir,
],
queue=False,
).then(**set_gallery_columns_immediate)
)
# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
@@ -489,4 +448,4 @@ with gr.Blocks() as outputgallery_web:
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)

View File

@@ -6,6 +6,7 @@ from transformers import (
AutoModelForCausalLM,
)
from apps.stable_diffusion.web.ui.utils import available_devices
from shark.iree_utils.compile_utils import clean_device_info
from datetime import datetime as dt
import json
import sys
@@ -132,27 +133,6 @@ def get_default_config():
c.split_into_layers()
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by LLM pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
device_id = int(device_id) # using device index in webui
if device not in ["rocm", "vulkan"]:
device_id = None
return device, device_id
model_vmfb_key = ""
@@ -456,7 +436,7 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
chatbot = gr.Chatbot(height=500)
chatbot = gr.Chatbot(elem_id="chatbot")
with gr.Row():
with gr.Column():
msg = gr.Textbox(

View File

@@ -0,0 +1,458 @@
import os
import torch
import time
import sys
import gradio as gr
from PIL import Image
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
Text2ImageSDXLPipeline,
get_schedulers,
set_init_device_flags,
utils,
save_output_img,
prompt_examples,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
init_use_tuned = args.use_tuned
init_import_mlir = args.import_mlir
def txt2img_sdxl_inf(
prompt: str,
negative_prompt: str,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: str | int,
batch_count: int,
batch_size: int,
scheduler: str,
model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
ondemand: bool,
repeatable_seeds: bool,
):
if precision != "fp16":
print("currently we support fp16 for SDXL")
precision = "fp16"
from apps.stable_diffusion.web.ui.utils import (
get_custom_model_pathfile,
get_custom_vae_or_lora_weights,
Config,
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
)
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.steps = steps
args.scheduler = scheduler
args.ondemand = ondemand
# set ckpt_loc and hf_model_id.
args.ckpt_loc = ""
args.hf_model_id = ""
args.custom_vae = ""
# .safetensor or .chkpt on the custom model path
if model_id in get_custom_model_files():
args.ckpt_loc = get_custom_model_pathfile(model_id)
# civitai download
elif "civitai" in model_id:
args.ckpt_loc = model_id
# either predefined or huggingface
else:
args.hf_model_id = model_id
# if custom_vae != "None":
# args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = ""
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
"txt2img_sdxl",
args.hf_model_id,
args.ckpt_loc,
args.custom_vae,
precision,
batch_size,
max_length,
height,
width,
device,
use_lora=args.use_lora,
use_stencil=None,
ondemand=ondemand,
)
if (
not global_obj.get_sd_obj()
or global_obj.get_cfg_obj() != new_config_obj
):
global_obj.clear_cache()
global_obj.set_cfg_obj(new_config_obj)
args.precision = precision
args.batch_count = batch_count
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
args.iree_metal_target_platform = init_iree_metal_target_platform
args.use_tuned = init_use_tuned
args.import_mlir = init_import_mlir
args.img_path = None
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-xl-base-1.0"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
# For SDXL we set max_length as 77.
print("Setting max_length = 77")
max_length = 77
if global_obj.get_cfg_obj().ondemand:
print("Running txt2img in memory efficient mode.")
txt2img_sdxl_obj = Text2ImageSDXLPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
precision=precision,
max_length=max_length,
batch_size=batch_size,
height=height,
width=width,
use_base_vae=args.use_base_vae,
use_tuned=args.use_tuned,
custom_vae=args.custom_vae,
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
ondemand=global_obj.get_cfg_obj().ondemand,
)
global_obj.set_sd_obj(txt2img_sdxl_obj)
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
generated_imgs = []
text_output = ""
try:
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
except TypeError as error:
raise gr.Error(str(error)) from None
for current_batch in range(batch_count):
out_imgs = global_obj.get_sd_obj().generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
seeds[current_batch],
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = get_generation_text_info(
seeds[: current_batch + 1], device
)
text_output += "\n" + global_obj.get_sd_obj().log
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break
else:
save_output_img(out_imgs[0], seeds[current_batch])
generated_imgs.extend(out_imgs)
yield generated_imgs, text_output, status_label(
"Text-to-Image-SDXL",
current_batch + 1,
batch_count,
batch_size,
)
return generated_imgs, text_output, ""
with gr.Blocks(title="Text-to-Image-SDXL") as txt2img_sdxl_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row():
with gr.Column(scale=1, elem_id="demo_title_outer"):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
)
with gr.Row(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
with gr.Column(scale=10):
with gr.Row():
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
txt2img_sdxl_custom_model = gr.Dropdown(
label=f"Models",
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-xl-base-1.0",
choices=[
"stabilityai/stable-diffusion-xl-base-1.0"
],
allow_custom_value=True,
scale=2,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
lines=2,
elem_id="negative_prompt_box",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="DDIM",
choices=["DDIM"],
allow_custom_value=True,
visible=False,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
interactive=True,
)
with gr.Row():
height = gr.Slider(
1024,
value=1024,
step=8,
label="Height",
visible=False,
)
width = gr.Slider(
1024,
value=1024,
step=8,
label="Width",
visible=False,
)
precision = gr.Radio(
label="Precision",
value="fp16",
choices=[
"fp16",
"fp32",
],
visible=False,
)
max_length = gr.Radio(
label="Max Length",
value=args.max_length,
choices=[
64,
77,
],
visible=False,
)
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
label="Low VRAM",
interactive=True,
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
with gr.Column(scale=3):
batch_size = gr.Slider(
1,
4,
value=args.batch_size,
step=1,
label="Batch Size",
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
choices=available_devices,
allow_custom_value=True,
)
with gr.Accordion(label="Prompt Examples!", open=False):
ex = gr.Examples(
examples=prompt_examples,
inputs=prompt,
cache_examples=False,
elem_id="prompt_examples",
)
with gr.Column(scale=1, min_width=600):
with gr.Group():
txt2img_sdxl_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=[2],
object_fit="contain",
)
std_output = gr.Textbox(
value=f"{t2i_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=1,
elem_id="std_output",
show_label=False,
)
txt2img_sdxl_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Row():
blank_thing_for_row = None
kwargs = dict(
fn=txt2img_sdxl_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
txt2img_sdxl_custom_model,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
ondemand,
repeatable_seeds,
],
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
show_progress="minimal" if args.progress_bar else "none",
)
status_kwargs = dict(
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=txt2img_sdxl_status,
)
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
**kwargs
)
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

View File

@@ -5,6 +5,7 @@ import sys
import gradio as gr
from PIL import Image
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -15,6 +16,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -124,7 +126,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -224,7 +226,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
use_stencil="None",
stencils=[],
ondemand=ondemand,
)
@@ -278,7 +280,7 @@ def txt2img_inf(
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil="None",
stencils=[],
resample_type=resample_type,
)
total_time = time.time() - start_time
@@ -396,6 +398,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -689,3 +696,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
outputs=[scheduler],
queue=False,
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -3,6 +3,7 @@ import torch
import time
import gradio as gr
from PIL import Image
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -12,6 +13,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
@@ -118,7 +120,7 @@ def upscaler_inf(
args.width,
device,
use_lora=args.use_lora,
use_stencil=None,
stencils=[],
ondemand=ondemand,
)
if (
@@ -340,6 +342,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -537,3 +544,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -1,10 +1,16 @@
import os
import sys
from apps.stable_diffusion.src import get_available_devices
import glob
import math
import json
import safetensors
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
from enum import IntEnum
from apps.stable_diffusion.src import get_available_devices
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
@@ -24,10 +30,19 @@ class Config:
width: int
device: str
use_lora: str
use_stencil: str
stencils: list[str]
ondemand: str # should this be expecting a bool instead?
class HSLHue(IntEnum):
RED = 0
YELLOW = 60
GREEN = 120
CYAN = 180
BLUE = 240
MAGENTA = 300
custom_model_filetypes = (
"*.ckpt",
"*.safetensors",
@@ -161,6 +176,69 @@ def get_custom_vae_or_lora_weights(weights, hf_id, model):
return use_weight
def hsl_color(alpha: float, start, end):
b = (end - start) * (alpha if alpha > 0 else 0)
result = b + start
# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"
def get_lora_metadata(lora_filename):
# get the metadata from the file
filename = get_custom_model_pathfile(lora_filename, "lora")
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
metadata = f.metadata()
# guard clause for if there isn't any metadata
if not metadata:
return None
# metadata is a dictionary of strings, the values of the keys we're
# interested in are actually json, and need to be loaded as such
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
tag_dirs = [dir for dir in tag_frequencies.keys()]
# gather the tag frequency information for all the datasets trained
all_frequencies = {}
for dataset in tag_dirs:
frequencies = sorted(
[entry for entry in tag_frequencies[dataset].items()],
reverse=True,
key=lambda x: x[1],
)
# get a figure for the total number of images processed for this dataset
# either then number actually listed or in its dataset_dir entry or
# the highest frequency's number if that doesn't exist
img_count = dataset_dirs.get(dir, {}).get(
"img_count", frequencies[0][1]
)
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
all_frequencies.update(
[(entry[0], entry[1] / img_count) for entry in frequencies]
)
trained_model_id = " ".join(
[
metadata.get("ss_sd_model_hash", ""),
metadata.get("ss_sd_model_name", ""),
metadata.get("ss_base_model_version", ""),
]
).strip()
# return the topmost <count> of all frequencies in all datasets
return {
"model": trained_model_id,
"frequencies": sorted(
all_frequencies.items(), reverse=True, key=lambda x: x[1]
),
}
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:

View File

@@ -5,11 +5,25 @@ from time import time
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
def config_gradio_tmp_imgs_folder():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
def clear_tmp_mlir():
cleanup_start = time()
print(
"Clearing .mlir temporary files from a prior run. This may take some time..."
)
mlir_files = [
filename
for filename in os.listdir(shark_tmp)
if os.path.isfile(os.path.join(shark_tmp, filename))
and filename.endswith(".mlir")
]
for filename in mlir_files:
os.remove(shark_tmp + filename)
print(
f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
)
def clear_tmp_imgs():
# tell gradio to use a directory under shark_tmp for its temporary
# image files unless somewhere else has been set
if "GRADIO_TEMP_DIR" not in os.environ:
@@ -52,3 +66,12 @@ def config_gradio_tmp_imgs_folder():
)
else:
print("No temporary images files to clear.")
def config_tmp():
# create shark_tmp if it does not exist
if not os.path.exists(shark_tmp):
os.mkdir(shark_tmp)
clear_tmp_mlir()
clear_tmp_imgs()

View File

@@ -78,7 +78,10 @@ def test_loop(
os.mkdir("./test_images/golden")
get_inpaint_inputs()
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned", "--use_tuned"]
tuned_options = [
"--no-use_tuned",
"--use_tuned",
]
import_options = ["--import_mlir", "--no-import_mlir"]
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
@@ -112,6 +115,8 @@ def test_loop(
and use_tune == tuned_options[1]
):
continue
elif use_tune == tuned_options[1]:
continue
command = (
[
executable, # executable is the python from the venv used to run this

View File

@@ -22,33 +22,33 @@ This does mean however, that on a brand new fresh install of SHARK that has not
* Make sure you have suitable drivers for your graphics card installed. See the prerequisties section of the [README](https://github.com/nod-ai/SHARK#readme).
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK default of `8080`, also include both the `--api_cors_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the if you want to run SHARK on a different port)
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the if you want to run SHARK on a different port)
```powershell
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_cors_origin="*" --server_port=7860
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_accept_origin="*" --server_port=7860
## Run trom the base directory of a source clone of SHARK on Windows:
.\setup_venv.ps1
python .\apps\stable_diffusion\web\index.py --api --api_cors_origin="*" --server_port=7860
python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860
## Run a the base directory of a source clone of SHARK on Linux:
./setup_venv.sh
source shark.venv/bin/activate
python ./apps/stable_diffusion/web/index.py --api --api_cors_origin="*" --server_port=7860
python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860
## Since the api respects most applicable SHARK command line arguments for options not specified,
## or currently unimplemented by API, there might be some you want to set, as listed in `--help`
.\node_ai_shark_studio_20320901_2525.exe --help
## For instance, the example above, but with a a custom VAE specified
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
## An example with multiple specific CORS origins
python apps/stable_diffusion/web/index.py --api --api_cors_origin="koboldcpp.example.com:7001" --api_cors_origin="koboldcpp.example.com:7002" --server_port=7860
python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
```
SHARK should start in server mode, and you should see something like this:

View File

@@ -26,7 +26,7 @@ sacremoses
sentencepiece
# web dependecies.
gradio
gradio==3.44.3
altair
scipy

View File

@@ -50,4 +50,8 @@ pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
# For quantized GPTQ models
optimum
auto_gptq

View File

@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html

View File

@@ -31,60 +31,64 @@ from .benchmark_utils import *
# Get the iree-compile arguments given device.
def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan", "rocm"]:
print(
f"Specific device selection only supported for vulkan and rocm."
f"Proceeding with {device} as device."
)
# device_uri can be device_num or device_path.
# assuming number of devices for a single driver will be not be >99
if len(device_uri[1]) <= 2:
# expected to be device index in range 0 - 99
device_num = int(device_uri[1])
else:
# expected to be device path
device_num = device_uri[1]
else:
device_num = 0
device, device_num = clean_device_info(device)
if "cpu" in device:
from shark.iree_utils.cpu_utils import get_iree_cpu_args
data_tiling_flag = ["--iree-opt-data-tiling"]
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
return (
get_iree_cpu_args()
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device_uri[0] == "cuda":
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
return get_iree_gpu_args()
if device_uri[0] == "vulkan":
if device == "vulkan":
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
return get_iree_vulkan_args(
device_num=device_num, extra_args=extra_args
)
if device_uri[0] == "metal":
if device == "metal":
from shark.iree_utils.metal_utils import get_iree_metal_args
return get_iree_metal_args(extra_args=extra_args)
if device_uri[0] == "rocm":
if device == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
return []
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by Studio pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
if len(device_id) <= 2:
device_id = int(device_id)
if device not in ["rocm", "vulkan"]:
device_id = ""
if device in ["rocm", "vulkan"] and device_id == None:
device_id = 0
return device, device_id
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:

View File

@@ -95,6 +95,7 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
print("could not execute `iree-run-module --dump_devices=rocm`")
if dump_device_info is not None:
device_num = 0 if device_num is None else device_num
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
if len(device_arch_pairs) > device_num: # can find arch in the list
arch_in_device_dump = device_arch_pairs[device_num][1]
@@ -103,7 +104,7 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
print(f"Found ROCm device arch : {arch_in_device_dump}")
return arch_in_device_dump
default_rocm_arch = "gfx_1100"
default_rocm_arch = "gfx1100"
print(
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
"\n or from `iree-run-module --dump_devices=rocm` command."

View File

@@ -38,15 +38,24 @@ def get_all_vulkan_devices():
@functools.cache
def get_vulkan_device_name(device_num=0):
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: {vulkaninfo_list[device_num]}")
return vulkaninfo_list[device_num]
if isinstance(device_num, int):
vulkaninfo_list = get_all_vulkan_devices()
if len(vulkaninfo_list) == 0:
raise ValueError("No device name found in VulkanInfo!")
if len(vulkaninfo_list) > 1:
print("Following devices found:")
for i, dname in enumerate(vulkaninfo_list):
print(f"{i}. {dname}")
print(f"Choosing device: vulkan://{device_num}")
vulkan_device_name = vulkaninfo_list[device_num]
else:
from iree.runtime import get_driver
vulkan_device_driver = get_driver(device_num)
vulkan_device_name = vulkan_device_driver.query_available_devices()[0]
print(vulkan_device_name)
return vulkan_device_name
def get_os_name():

View File

@@ -800,15 +800,17 @@ def save_mlir(
model_name,
mlir_dialect="linalg",
frontend="torch",
dir=tempfile.gettempdir(),
dir="",
):
model_name_mlir = (
model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
)
if dir == "":
dir = tempfile.gettempdir()
dir = os.path.join(".", "shark_tmp")
mlir_path = os.path.join(dir, model_name_mlir)
print(f"saving {model_name_mlir} to {dir}")
if not os.path.exists(dir):
os.makedirs(dir)
if frontend == "torch":
with open(mlir_path, "wb") as mlir_file:
mlir_file.write(mlir_module)

View File

@@ -1,21 +1,19 @@
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
1 bert-base-uncased linalg torch 1e-2 1e-3 default None False True False False
2 bert-base-uncased_fp16 bert-large-uncased linalg torch 1e-1 1e-2 1e-1 1e-3 default None True False True True False False
3 bert-large-uncased facebook/deit-small-distilled-patch16-224 linalg torch 1e-2 1e-3 default None nhcw-nhwc False True False False Fails during iree-compile.
4 facebook/deit-small-distilled-patch16-224 google/vit-base-patch16-224 linalg torch 1e-2 1e-3 default nhcw-nhwc False True True False True True Fails during iree-compile. https://github.com/nod-ai/SHARK/issues/311
5 google/vit-base-patch16-224 microsoft/beit-base-patch16-224-pt22k-ft22k linalg torch 1e-2 1e-3 default nhcw-nhwc False True False False https://github.com/nod-ai/SHARK/issues/311 https://github.com/nod-ai/SHARK/issues/390 macos
6 microsoft/beit-base-patch16-224-pt22k-ft22k microsoft/MiniLM-L12-H384-uncased linalg torch 1e-2 1e-3 default nhcw-nhwc None False True False False https://github.com/nod-ai/SHARK/issues/390 macos
7 microsoft/MiniLM-L12-H384-uncased google/mobilebert-uncased linalg torch 1e-2 1e-3 default None False True False False https://github.com/nod-ai/SHARK/issues/344 macos
8 google/mobilebert-uncased mobilenet_v3_small linalg torch 1e-2 1e-1 1e-3 1e-2 default None nhcw-nhwc False True False False https://github.com/nod-ai/SHARK/issues/344 https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487 macos
mobilenet_v3_small linalg torch 1e-1 1e-2 default nhcw-nhwc True True True https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487 macos
9 nvidia/mit-b0 linalg torch 1e-2 1e-3 default None True True True True https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487 macos
10 resnet101 linalg torch 1e-2 1e-3 default nhcw-nhwc/img2col True False False True True macos
11 resnet18 linalg torch 1e-2 1e-3 default None True True False True True macos
12 resnet50 linalg torch 1e-2 1e-3 default nhcw-nhwc False False False False macos
resnet50_fp16 linalg torch 1e-2 1e-2 default nhcw-nhwc/img2col True True True Numerics issues, awaiting cuda-independent fp16 integration
13 squeezenet1_0 linalg torch 1e-2 1e-3 default nhcw-nhwc False False False False macos
14 wide_resnet50_2 linalg torch 1e-2 1e-3 default nhcw-nhwc/img2col True False False True True macos
15 mnasnet1_0 linalg torch 1e-2 1e-3 default nhcw-nhwc True False True True False False macos
16 efficientnet_b0 linalg torch 1e-2 1e-3 default nhcw-nhwc True True True True https://github.com/nod-ai/SHARK/issues/1487 macos
17 efficientnet_b7 linalg torch 1e-2 1e-3 default nhcw-nhwc True True True True https://github.com/nod-ai/SHARK/issues/1487 macos
18 t5-base linalg torch 1e-2 1e-3 default None True True True True Inputs for seq2seq models in torch currently unsupported. macos
19 t5-large linalg torch 1e-2 1e-3 default None True True True True Inputs for seq2seq models in torch currently unsupported macos

View File

@@ -50,7 +50,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
is_decompose = row[5]
tracing_required = False if tracing_required == "False" else True
is_dynamic = False if is_dynamic == "False" else True
is_dynamic = False
print("generating artifacts for: " + torch_model_name)
model = None
input = None
@@ -104,7 +104,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
model_name=torch_model_name,
mlir_type=mlir_type,
is_dynamic=False,
tracing_required=tracing_required,
tracing_required=True,
)
else:
mlir_importer = SharkImporter(
@@ -114,7 +114,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
)
mlir_importer.import_debug(
is_dynamic=False,
tracing_required=tracing_required,
tracing_required=True,
dir=torch_model_dir,
model_name=torch_model_name,
mlir_type=mlir_type,
@@ -123,7 +123,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
if is_dynamic:
mlir_importer.import_debug(
is_dynamic=True,
tracing_required=tracing_required,
tracing_required=True,
dir=torch_model_dir,
model_name=torch_model_name + "_dynamic",
mlir_type=mlir_type,