Compare commits

..

11 Commits

Author SHA1 Message Date
jinchen62
534de05791 Update precision check for vicuna (#1610) 2023-06-29 16:16:33 -05:00
Daniel Garvey
5779e8c039 int4/int8 vicuna download support (#1609)
* set task_topology_max_group to cpu_count

by default. Can be overriden with a flag of the same str

* add download for int4/int8 mlir
2023-06-29 13:35:51 -07:00
Abhishek Varma
d496053590 [SHARK] Add a compile API to use for quick testing of inference (#1606) 2023-06-28 08:40:28 -07:00
gpetters94
6274a813c9 Add unet512 support for the other StableDiffusion pipelines (#1602) 2023-06-27 12:28:57 -07:00
Gaurav Shukla
1d6a1f9f8a [vicuna] Add tokens streaming(step=3) (#1600)
Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-06-27 08:59:27 -07:00
Daniel Garvey
75672c0e28 set task_topology_max_group to cpu_count (#1594)
by default. Can be overriden with a flag of the same str
2023-06-26 14:54:06 -07:00
Prashant Kumar
74a7202173 Make the tensors contiguous. 2023-06-26 17:29:54 +05:30
Prashant Kumar
27a08735db Add the shark backend for torch.compile API. (#1596) 2023-06-26 03:53:32 -07:00
Stefan Kapusniak
eaa49cce17 UI/App - Allow text selection (#1593)
* When run in app mode on windows, allows selection of text from
non-input controls, which is the same behaviour as web mode.
2023-06-26 02:16:53 -07:00
powderluv
10657d6fb1 Disable upx 2023-06-25 07:28:52 -07:00
Stefan Kapusniak
e3ab844cd1 Fix output gallery for csv format inc. VAE & LoRA (#1591) 2023-06-24 06:20:53 -07:00
30 changed files with 444 additions and 258 deletions

View File

@@ -38,8 +38,10 @@ class Vicuna(SharkLLMBase):
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
if precision in ["int4", "int8"]:
print("int4 and int8 are not supported yet, using fp32")
if not load_mlir_from_shark_tank and precision in ["int4", "int8"]:
print(
"int4 and int8 are only available from SHARK tank, please set --load_mlir_from_shark_tank, using fp32 now"
)
precision = "fp32"
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
@@ -103,8 +105,8 @@ class Vicuna(SharkLLMBase):
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision in ["fp32", "fp16"]:
# download MLIR from shark_tank for fp32/fp16
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank
download_public_file(
f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
self.first_vicuna_mlir_path.absolute(),
@@ -121,7 +123,7 @@ class Vicuna(SharkLLMBase):
)
else:
print(
f"Only fp32 and fp16 mlir added to tank, generating {self.precision} mlir on device."
f"Only fp32/fp16/int8/int4 mlir added to tank, generating {self.precision} mlir on device."
)
if not mlir_generated:
@@ -245,8 +247,8 @@ class Vicuna(SharkLLMBase):
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision in ["fp32", "fp16"]:
# download MLIR from shark_tank for fp32/fp16
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank
download_public_file(
f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
self.second_vicuna_mlir_path.absolute(),
@@ -263,7 +265,7 @@ class Vicuna(SharkLLMBase):
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
"Only fp32/fp16/int8/int4 mlir added to tank, generating mlir on device."
)
if not mlir_generated:
@@ -439,6 +441,14 @@ class Vicuna(SharkLLMBase):
# return tuple of shark_modules once mem is supported
# return fvic_shark_model, svic_shark_model
def decode_tokens(self, res_tokens):
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = self.tokenizer.decode(res_tokens)
return res_str
def generate(self, prompt, cli=False):
# TODO: refactor for cleaner integration
import gc
@@ -448,7 +458,6 @@ class Vicuna(SharkLLMBase):
self.first_vic = self.compile_first_vicuna()
if self.second_vic == None:
self.second_vic = self.compile_second_vicuna()
res = []
res_tokens = []
params = {
"prompt": prompt,
@@ -464,8 +473,8 @@ class Vicuna(SharkLLMBase):
logits = generated_token_op["logits"]
pkv = generated_token_op["pkv"]
detok = generated_token_op["detok"]
yield detok
res.append(detok)
res_tokens.append(token)
if cli:
print(f"Assistant: {detok}", end=" ", flush=True)
@@ -498,25 +507,24 @@ class Vicuna(SharkLLMBase):
break
res_tokens.append(token)
if detok == "<0x0A>":
res.append("\n")
if cli:
print("\n", end="", flush=True)
else:
res.append(detok)
if cli:
print(f"{detok}", end=" ", flush=True)
if len(res_tokens) % 3 == 0:
part_str = self.decode_tokens(res_tokens)
yield part_str
if self.device == "cuda":
del sec_vic, pkv, logits
torch.cuda.empty_cache()
gc.collect()
for i in range(len(res_tokens)):
if type(res_tokens[i]) != int:
res_tokens[i] = int(res_tokens[i][0])
res_str = self.tokenizer.decode(res_tokens)
res_str = self.decode_tokens(res_tokens)
# print(f"[DEBUG] final output : \n{res_str}")
return res_str
yield res_str
def generate_new_token(self, params, debug=False):
def forward_first(first_vic, prompt, cache_outputs=False):

View File

@@ -103,6 +103,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
total_time = time.time() - start_time

View File

@@ -81,6 +81,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -79,6 +79,7 @@ def main():
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -73,6 +73,7 @@ if __name__ == "__main__":
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"

View File

@@ -78,7 +78,7 @@ exe = EXE(
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,

View File

@@ -520,16 +520,17 @@ class SharkifyStableDiffusionModel:
torch.nn.functional.pad(inputs[2], pad),
inputs[3])
input_mask = [True, True, True, False]
model_name = "unet512" if use_large else "unet"
shark_unet, unet_mlir = compile_through_fx(
unet,
inputs,
extended_model_name=self.model_name["unet"],
extended_model_name=self.model_name[model_name],
is_f16=is_f16,
f16_input_mask=input_mask,
use_tuned=self.use_tuned,
extra_args=get_opt_flags("unet", precision=self.precision),
base_model_id=self.base_model_id,
model_name="unet",
model_name=model_name,
precision=self.precision,
return_mlir=self.return_mlir,
)

View File

@@ -135,6 +135,7 @@ class Image2ImagePipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# prompts and negative prompts must be a list.
@@ -156,7 +157,10 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -378,6 +378,7 @@ class InpaintPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -408,7 +409,10 @@ class InpaintPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -379,6 +379,7 @@ class OutpaintPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -409,7 +410,10 @@ class OutpaintPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -204,6 +204,7 @@ class StencilPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
use_stencil,
):
# Control Embedding check & conversion
@@ -230,7 +231,10 @@ class StencilPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# guidance scale as a float32 tensor.

View File

@@ -168,7 +168,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
text_embeddings_numpy = text_embeddings.detach().numpy()
self.status = SD_STATE_IDLE
self.load_unet()
if text_embeddings.shape[1] <= self.model_max_length:
self.load_unet()
else:
self.load_unet_512()
for i, t in tqdm(enumerate(total_timesteps)):
step_start_time = time.time()
latent_model_input = torch.cat([latents] * 2)
@@ -182,15 +185,26 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Profiling Unet.
profile_device = start_profiling(file_path="unet.rdc")
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
if text_embeddings.shape[1] <= self.model_max_length:
noise_pred = self.unet(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
else:
noise_pred = self.unet_512(
"forward",
(
latent_model_input,
timestep,
text_embeddings_numpy,
noise_level,
),
)
end_profiling(profile_device)
noise_pred = torch.from_numpy(noise_pred)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -219,6 +233,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
if self.ondemand:
self.unload_unet()
self.unload_unet_512()
avg_step_time = step_time_sum / len(total_timesteps)
self.log += f"\nAverage step time: {avg_step_time}ms/it"
@@ -243,6 +258,7 @@ class UpscalerPipeline(StableDiffusionPipeline):
dtype,
use_base_vae,
cpu_scheduling,
max_embeddings_multiples,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
@@ -264,7 +280,10 @@ class UpscalerPipeline(StableDiffusionPipeline):
# Get text embeddings with weight emphasis from prompts
text_embeddings = self.encode_prompts_weight(
prompts, neg_prompts, max_length
prompts,
neg_prompts,
max_length,
max_embeddings_multiples=max_embeddings_multiples,
)
# 4. Preprocess image

View File

@@ -810,8 +810,11 @@ def save_output_img(output_img, img_seed, extra_info={}):
new_entry.update(extra_info)
with open(csv_path, "a", encoding="utf-8") as csv_obj:
csv_mode = "a" if os.path.isfile(csv_path) else "w"
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
if csv_mode == "w":
dictwriter_obj.writeheader()
dictwriter_obj.writerow(new_entry)
csv_obj.close()

View File

@@ -30,7 +30,11 @@ def launch_app(address):
width = window.winfo_screenwidth()
height = window.winfo_screenheight()
webview.create_window(
"SHARK AI Studio", url=address, width=width, height=height
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False)

View File

@@ -249,6 +249,7 @@ def img2img_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
use_stencil=use_stencil,
)
seeds.append(img_seed)

View File

@@ -204,6 +204,7 @@ def inpaint_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time

View File

@@ -211,6 +211,7 @@ def outpaint_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
seeds.append(img_seed)
total_time = time.time() - start_time

View File

@@ -65,16 +65,11 @@ def chat(curr_system_message, history, model, device, precision):
)
prompt = messages.strip()
print("prompt = ", prompt)
sentence = vicuna_model.generate(prompt)
partial_text = ""
for new_text in sentence.split(" "):
# print(new_text)
partial_text += new_text + " "
for partial_text in vicuna_model.generate(prompt):
history[-1][1] = partial_text
# Yield an empty string to cleanup the message textbox and the updated conversation history
yield history
history[-1][1] = sentence
return history
# else Model is StableLM

View File

@@ -202,6 +202,7 @@ def upscaler_inf(
dtype,
args.use_base_vae,
cpu_scheduling,
args.max_embeddings_multiples,
)
if global_obj.get_sd_status() == SD_STATE_CANCEL:
break

View File

@@ -11,21 +11,35 @@ def has_csv(image_filename: str) -> bool:
return os.path.exists(csv_path(image_filename))
def parse_csv(image_filename: str):
# We use a reader instead of a DictReader here for images_details.csv files due to the lack of
# headers, and then match up the return list for each row with our guess at which column format
# the file is using.
def matching_filename(image_filename: str, row):
# we assume the final column of the csv has the original filename with full path and match that
# against the image_filename. We then exclude the filename from the output, hence the -1's.
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
# the value of the OUTPUT key
return os.path.basename(image_filename) in (
row[-1] if isinstance(row, list) else row["OUTPUT"]
)
def parse_csv(image_filename: str):
csv_filename = csv_path(image_filename)
matches = [
humanize(row)
for row in csv.reader(open(csv_filename, "r", newline=""))
if row
and humanizable(row)
and os.path.basename(image_filename) in row[-1]
]
with open(csv_filename, "r", newline="") as csv_file:
# We use a reader or DictReader here for images_details.csv depending on whether we think it
# has headers or not. Having headers means less guessing of the format.
has_header = csv.Sniffer().has_header(csv_file.read(2048))
csv_file.seek(0)
reader = (
csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
)
matches = [
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
humanize(row)
for row in reader
if row
and (has_header or humanizable(row))
and matching_filename(image_filename, row)
]
return matches[0] if matches else {}

View File

@@ -50,7 +50,22 @@ PARAMS_FORMATS = {
},
}
PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
PARAMS_FORMAT_CURRENT = {
"VARIANT": "Model",
"VAE": "VAE",
"LORA": "LoRA",
"SCHEDULER": "Sampler",
"PROMPT": "Prompt",
"NEG_PROMPT": "Negative prompt",
"SEED": "Seed",
"CFG_SCALE": "CFG scale",
"PRECISION": "Precision",
"STEPS": "Steps",
"HEIGHT": "Height",
"WIDTH": "Width",
"MAX_LENGTH": "Max Length",
"OUTPUT": "Filename",
}
def compact(metadata: dict) -> dict:
@@ -97,19 +112,20 @@ def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
)
# For dictionaries we try to use the matching length parameter format if
# available, otherwise we use the longest. Then we swap keys in the
# metadata that match keys in the format for the friendlier name that we
# have set in the format value
# available, otherwise we just use the current format which is assumed to
# have everything currently known about. Then we swap keys in the metadata
# that match keys in the format for the friendlier name that we have set
# in the format value
if isinstance(metadata, dict):
if humanizable(metadata, includes_filename):
format = PARAMS_FORMATS[lookup_key]
else:
format = PARAMS_FORMAT_LONGEST
format = PARAMS_FORMAT_CURRENT
return {
format[key]: value
for (key, value) in metadata.items()
if key in format.keys()
format[key]: metadata[key]
for key in format.keys()
if key in metadata.keys() and metadata[key]
}
raise TypeError("Can only humanize parameter lists or dictionaries")

View File

@@ -0,0 +1,154 @@
import functools
from typing import List, Optional
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
from shark.shark_inference import SharkInference
from torch._decomp import get_decompositions
from torch.func import functionalize
import io
import torch_mlir
# TODO: Control decompositions.
def default_decompositions():
return get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
torch.ops.aten.masked_fill.Tensor,
torch.ops.aten.masked_fill.Scalar,
]
)
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
removed_indexes = []
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, (list, tuple)):
node_arg = list(node_arg)
node_args_len = len(node_arg)
for i in range(node_args_len):
curr_index = node_args_len - (i + 1)
if node_arg[curr_index] is None:
removed_indexes.append(curr_index)
node_arg.pop(curr_index)
node.args = (tuple(node_arg),)
break
if len(removed_indexes) > 0:
fx_g.graph.lint()
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
removed_indexes.sort()
return removed_indexes
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
class SharkBackend:
def __init__(
self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
):
self.fx_g = fx_g
self.inputs = inputs
self.shark_module = None
self.device: str = options.get("device", "cpu")
self.was_unwrapped: bool = False
self.none_indices: list = []
self._modify_fx_g()
self.compile()
def _modify_fx_g(self):
self.none_indices = _remove_nones(self.fx_g)
self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)
def compile(self):
gm = make_fx(
functionalize(self.fx_g),
decomposition_table=default_decompositions(),
)(*self.inputs)
gm.graph.set_codegen(torch.fx.graph.CodeGen())
gm.recompile()
strip_overloads(gm)
ts_g = torch.jit.script(gm)
mlir_module = torch_mlir.compile(
ts_g, self.inputs, output_type="linalg-on-tensors"
)
bytecode_stream = io.BytesIO()
mlir_module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
from shark.shark_inference import SharkInference
shark_module = SharkInference(
mlir_module=bytecode,
device=self.device,
mlir_dialect="tm_tensor",
)
shark_module.compile(extra_args=[])
self.shark_module = shark_module
def __call__(self, *inputs):
np_inputs = [x.contiguous().detach().cpu().numpy() for x in inputs]
np_outs = self.shark_module("forward", np_inputs)
if self.was_unwrapped:
np_outs = [
np_outs,
]
if not isinstance(np_outs, list):
res = torch.from_numpy(np_outs)
return res
result = [torch.from_numpy(x) for x in np_outs]
for r_in in self.none_indices:
result.insert(r_in, None)
result = tuple(result)
return result

View File

@@ -1,10 +1,7 @@
import torch
import torch_mlir
from shark.shark_inference import SharkInference
from apps.stable_diffusion.src.utils import (
compile_through_fx,
args,
)
from shark.shark_compile import shark_compile_through_fx
from MEGABYTE_pytorch import MEGABYTE
import os
@@ -37,23 +34,22 @@ class MegaModel(torch.nn.Module):
megaModel = MegaModel()
input = [torch.randint(0, 16000, (1, 1024, 4))]
inputs = [torch.randint(0, 16000, (1, 1024, 4))]
# CURRENTLY IT BAILS OUT HERE BECAUSE OF MISSING OP LOWERINGS :-
# 1. aten.alias
shark_module, _ = compile_through_fx(
megaModel,
inputs=input,
shark_module, _ = shark_compile_through_fx(
model=megaModel,
inputs=inputs,
extended_model_name="mega_shark",
debug=False,
generate_vmfb=True,
is_f16=False,
f16_input_mask=None,
save_dir=os.getcwd(),
debug=False,
generate_or_load_vmfb=True,
extra_args=[],
base_model_id=None,
model_name="mega_shark",
precision=None,
return_mlir=True,
device="cuda",
mlir_dialect="tm_tensor",
)
# logits = model(x)
@@ -63,10 +59,10 @@ def print_output_info(output, msg):
print("\n\t", output.shape)
ans = shark_module("forward", input)
ans = shark_module("forward", inputs)
print_output_info(torch.from_numpy(ans), "SHARK's output")
ans = megaModel.forward(*input)
ans = megaModel.forward(*inputs)
print_output_info(ans, "ORIGINAL Model's output")
# and sample from the logits accordingly

View File

@@ -14,6 +14,7 @@
import iree.runtime as ireert
import iree.compiler as ireec
from shark.iree_utils._common import iree_device_map, iree_target_map
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
from shark.iree_utils.benchmark_utils import *
from shark.parser import shark_args
import numpy as np
@@ -352,6 +353,12 @@ def load_vmfb_using_mmap(
config = ireert.Config(device=haldevice)
else:
config = get_iree_runtime_config(device)
if "task" in device:
print(
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
)
for flag in get_iree_cpu_rt_args():
ireert.flags.parse_flags(flag)
# Now load vmfb.
# Two scenarios we have here :-
# 1. We either have the vmfb already saved and therefore pass the path of it.
@@ -359,7 +366,6 @@ def load_vmfb_using_mmap(
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
# (This would arise if we're invoking `compile` from a SharkInference obj)
temp_file_to_unlink = None
if isinstance(flatbuffer_blob_or_path, Path):
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
if (

View File

@@ -16,6 +16,7 @@
import subprocess
import platform
from shark.parser import shark_args
def get_cpu_count():
@@ -44,4 +45,18 @@ def get_iree_cpu_args():
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
raise Exception(error_message)
print(f"Target triple found:{target_triple}")
return [f"--iree-llvmcpu-target-triple={target_triple}"]
return [
f"--iree-llvmcpu-target-triple={target_triple}",
]
# Get iree runtime flags for cpu
def get_iree_cpu_rt_args():
default = get_cpu_count()
default = default if default <= 8 else default - 2
cpu_count = (
default
if shark_args.task_topology_max_group_count is None
else shark_args.task_topology_max_group_count
)
return [f"--task_topology_max_group_count={cpu_count}"]

View File

@@ -119,5 +119,11 @@ parser.add_argument(
"to augment the base device allocator",
choices=["debug", "caching"],
)
parser.add_argument(
"--task_topology_max_group_count",
type=str,
default=None,
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
)
shark_args, unknown = parser.parse_known_args()

99
shark/shark_compile.py Normal file
View File

@@ -0,0 +1,99 @@
import os
import tempfile
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
shark_module = None
if os.path.isfile(vmfb_path):
shark_module = SharkInference(
None,
device=device,
mlir_dialect=mlir_dialect,
)
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
def compile_module(
shark_module, extended_model_name, generate_vmfb, extra_args=[]
):
if generate_vmfb:
vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
if os.path.isfile(vmfb_path):
print(f"loading existing vmfb from: {vmfb_path}")
shark_module.load_module(vmfb_path, extra_args=extra_args)
else:
print(
"No vmfb found. Compiling and saving to {}".format(vmfb_path)
)
path = shark_module.save_module(
os.getcwd(), extended_model_name, extra_args
)
shark_module.load_module(path, extra_args=extra_args)
else:
shark_module.compile(extra_args)
return shark_module
def shark_compile_through_fx(
model,
inputs,
extended_model_name,
is_f16=False,
f16_input_mask=None,
save_dir=tempfile.gettempdir(),
debug=False,
generate_or_load_vmfb=True,
extra_args=[],
device=None,
mlir_dialect="tm_tensor",
):
if generate_or_load_vmfb:
shark_module = load_vmfb(
extended_model_name=extended_model_name,
device=device,
mlir_dialect=mlir_dialect,
extra_args=extra_args,
)
if shark_module:
return (
shark_module,
None,
)
from shark.parser import shark_args
if "cuda" in device:
shark_args.enable_tf32 = True
(
mlir_module,
_,
) = import_with_fx(
model=model,
inputs=inputs,
is_f16=is_f16,
f16_input_mask=f16_input_mask,
debug=debug,
model_name=extended_model_name,
save_dir=save_dir,
)
shark_module = SharkInference(
mlir_module,
device=device,
mlir_dialect=mlir_dialect,
)
return (
compile_module(
shark_module,
extended_model_name,
generate_vmfb=generate_or_load_vmfb,
extra_args=extra_args,
),
mlir_module,
)

View File

@@ -1,11 +0,0 @@
1. Install torchdynamo
- `git clone https://github.com/pytorch/torchdynamo.git`
- `cd torchdynamo`
- `python -m pip install -r requirements.txt`
- `python setup.py develop`
2. Install functorch
- `python -m pip install -v "git+https://github.com/pytorch/pytorch.git@$(python -c "import torch.version; print(torch.version.git_version)")#subdirectory=functorch"`
3. Run examples.
- `python shark/examples/shark_dynamo/basic_examples.py`

View File

@@ -1,163 +0,0 @@
import functools
import time
from typing import List, Optional
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
from shark.shark_inference import SharkInference
from torch._decomp import get_decompositions
import torch_mlir
# TODO: Control decompositions.
def default_decompositions():
return get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
)
def timeit(*, append_time_to: Optional[List] = None):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
start_time = time.time_ns()
result = func(*args, **kwargs)
end_time = time.time_ns()
if append_time_to is not None:
append_time_to.append(end_time - start_time)
return result
return wrapper
return decorator
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
return len(node_arg) == 0
return False
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
"""
Replace tuple with tuple element in functions that return one-element tuples.
Returns true if an unwrapping took place, and false otherwise.
"""
unwrapped_tuple = False
for node in fx_g.graph.nodes:
if node.op == "output":
assert (
len(node.args) == 1
), "Output node must have a single argument"
node_arg = node.args[0]
if isinstance(node_arg, tuple):
if len(node_arg) == 1:
node.args = (node_arg[0],)
unwrapped_tuple = True
break
if unwrapped_tuple:
fx_g.graph.lint()
fx_g.recompile()
return unwrapped_tuple
def make_shark_compiler(use_tracing: bool, device: str, verbose=False):
def compiler(
fx_graph: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
):
"""Compile GraphModule using torch-mlir + SHARK."""
if verbose:
print("Compiling graph...")
if _returns_nothing(fx_graph):
return fx_graph
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
fx_graph = make_fx(
fx_graph, decomposition_table=default_decompositions()
)(*example_inputs)
strip_overloads(fx_graph)
if verbose:
print("torch.fx graph:")
print(fx_graph.graph)
ts_compiler = torch.jit.trace if use_tracing else torch.jit.script
ts_graph = ts_compiler(fx_graph, example_inputs)
if verbose:
torch_mlir_module = torch_mlir.compile(
ts_graph,
example_inputs,
output_type=torch_mlir.OutputType.TORCH,
)
print("\n\ntorch-mlir backend contract graph:")
print(torch_mlir_module)
linalg_module = torch_mlir.compile(
ts_graph,
example_inputs,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
)
import io
bytecode_stream = io.BytesIO()
linalg_module.operation.write_bytecode(bytecode_stream)
mlir_module = bytecode_stream.getvalue()
shark_module = SharkInference(
mlir_module, mlir_dialect="linalg", device=device
)
shark_module.compile()
def forward(*inputs):
result = shark_module("forward", inputs)
result = tuple() if result is None else result
return (result,) if was_unwrapped else result
return forward
return compiler
def check_results(compiled_results, eager_results):
for compiled_result, eager_result in zip(compiled_results, eager_results):
if not torch.allclose(
compiled_result.to("cpu"), eager_result.to("cpu"), atol=1e-5
):
print("Compiled result does not match eager result")
return
print("Compiled result matches eager result!")
def print_time_stats(times):
times_tensor = torch.tensor(times)
def quantile_ms(q):
return torch.quantile(times_tensor.to(float), q).item() / 1e6
print(f"Median: {quantile_ms(0.5)} ms")
print(f"10%ile: {quantile_ms(0.1)} ms")
print(f"90%ile: {quantile_ms(0.9)} ms")
print(f"Total: {torch.sum(times_tensor) / 1e6} ms")
print()