mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
4 Commits
20230817.8
...
20230821.9
| Author | SHA1 | Date | |
|---|---|---|---|
| | b87efe7686 | | |
| | 82b462de3a | | |
| | d8f0f7bade | | |
| | 79bd0b84a1 | | |
@@ -283,7 +283,7 @@ class VicunaBase(SharkLLMBase):
|
||||
vnames.append(vname)
|
||||
if "true" not in vname:
|
||||
global_vars.append(
|
||||
f"ml_program.global public @{vname}({vbody}) : {fixed_vdtype}"
|
||||
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
|
||||
@@ -293,7 +293,7 @@ class VicunaBase(SharkLLMBase):
|
||||
)
|
||||
else:
|
||||
global_vars.append(
|
||||
f"ml_program.global public @{vname}({vbody}) : i1"
|
||||
f"ml_program.global private @{vname}({vbody}) : i1"
|
||||
)
|
||||
global_var_loading1.append(
|
||||
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
|
||||
|
||||
@@ -34,7 +34,7 @@ from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
||||
from diffusers.models.attention_processor import LoRAXFormersAttnProcessor
|
||||
|
||||
import torch_mlir
|
||||
from torch_mlir.dynamo import make_simple_dynamo_backend
|
||||
@@ -287,7 +287,7 @@ def lora_train(
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRACrossAttnProcessor(
|
||||
lora_attn_procs[name] = LoRAXFormersAttnProcessor(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
@@ -180,6 +180,7 @@ class SharkifyStableDiffusionModel:
|
||||
"vae",
|
||||
"vae_encode",
|
||||
"stencil_adaptor",
|
||||
"stencil_adaptor_512",
|
||||
]
|
||||
index = 0
|
||||
for model in sub_model_list:
|
||||
@@ -449,7 +450,7 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_controlled_unet, controlled_unet_mlir
|
||||
|
||||
def get_control_net(self):
|
||||
def get_control_net(self, use_large=False):
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.use_stencil, low_cpu_mem_usage=False
|
||||
@@ -497,17 +498,34 @@ class SharkifyStableDiffusionModel:
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["stencil_adaptor"])
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
inputs = (
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3],
|
||||
)
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
|
||||
)
|
||||
else:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor"]
|
||||
)
|
||||
input_mask = [True, True, True, True]
|
||||
# The large variant (inputs padded to 512 when use_large is set) must use the "stencil_adaptor_512" key; the base variant keeps "stencil_adaptor".
model_name = "stencil_adaptor_512" if use_large else "stencil_adaptor"
|
||||
shark_cnet, cnet_mlir = compile_through_fx(
|
||||
scnet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["stencil_adaptor"],
|
||||
extended_model_name=self.model_name[model_name],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="stencil_adaptor",
|
||||
model_name=model_name,
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
@@ -847,12 +865,14 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def controlnet(self):
|
||||
def controlnet(self, use_large=False):
|
||||
try:
|
||||
self.inputs["stencil_adaptor"] = self.get_input_info_for(
|
||||
base_models["stencil_adaptor"]
|
||||
)
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
|
||||
use_large=use_large
|
||||
)
|
||||
|
||||
check_compilation(compiled_stencil_adaptor, "Stencil")
|
||||
if self.return_mlir:
|
||||
|
||||
@@ -58,6 +58,7 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.controlnet = None
|
||||
self.controlnet_512 = None
|
||||
|
||||
def load_controlnet(self):
|
||||
if self.controlnet is not None:
|
||||
@@ -68,6 +69,15 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
del self.controlnet
|
||||
self.controlnet = None
|
||||
|
||||
def load_controlnet_512(self):
|
||||
if self.controlnet_512 is not None:
|
||||
return
|
||||
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
|
||||
|
||||
def unload_controlnet_512(self):
|
||||
del self.controlnet_512
|
||||
self.controlnet_512 = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
@@ -111,8 +121,12 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
self.load_controlnet_512()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype)
|
||||
@@ -135,16 +149,28 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
).to(dtype)
|
||||
else:
|
||||
latent_model_input_1 = latent_model_input
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
if text_embeddings.shapes[1] <= self.model_max_length:
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
timestep = timestep.detach().numpy()
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
@@ -191,7 +217,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
self.unload_controlnet()
|
||||
self.unload_controlnet_512()
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ parameterized
|
||||
|
||||
# Add transformers, diffusers and scipy since it most commonly used
|
||||
transformers
|
||||
diffusers==0.19.3
|
||||
diffusers
|
||||
#accelerate is now required for diffusers import from ckpt.
|
||||
accelerate
|
||||
scipy
|
||||
|
||||
@@ -146,7 +146,7 @@ if [[ $(uname -s) = 'Linux' && ! -z "${BENCHMARK}" ]]; then
|
||||
fi
|
||||
|
||||
if [[ -z "${NO_BREVITAS}" ]]; then
|
||||
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@llm
|
||||
$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@dev
|
||||
fi
|
||||
|
||||
if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then
|
||||
|
||||
Reference in New Issue
Block a user