mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-04-20 03:00:34 -04:00
Compare commits
8 Commits
20230623.7
...
20230626.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75672c0e28 | ||
|
|
74a7202173 | ||
|
|
27a08735db | ||
|
|
eaa49cce17 | ||
|
|
10657d6fb1 | ||
|
|
e3ab844cd1 | ||
|
|
5ce6001b41 | ||
|
|
501d0ca52e |
@@ -78,7 +78,7 @@ exe = EXE(
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx=False,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
|
||||
@@ -810,8 +810,11 @@ def save_output_img(output_img, img_seed, extra_info={}):
|
||||
|
||||
new_entry.update(extra_info)
|
||||
|
||||
with open(csv_path, "a", encoding="utf-8") as csv_obj:
|
||||
csv_mode = "a" if os.path.isfile(csv_path) else "w"
|
||||
with open(csv_path, csv_mode, encoding="utf-8") as csv_obj:
|
||||
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
|
||||
if csv_mode == "w":
|
||||
dictwriter_obj.writeheader()
|
||||
dictwriter_obj.writerow(new_entry)
|
||||
csv_obj.close()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ if sys.platform == "darwin":
|
||||
import torch_mlir
|
||||
|
||||
import shutil
|
||||
import PIL, transformers # ensures inclusion in pysintaller exe generation
|
||||
import PIL, sentencepiece, transformers # ensures inclusion in pysintaller exe generation
|
||||
from apps.stable_diffusion.src import args, clear_all
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
|
||||
@@ -30,7 +30,11 @@ def launch_app(address):
|
||||
width = window.winfo_screenwidth()
|
||||
height = window.winfo_screenheight()
|
||||
webview.create_window(
|
||||
"SHARK AI Studio", url=address, width=width, height=height
|
||||
"SHARK AI Studio",
|
||||
url=address,
|
||||
width=width,
|
||||
height=height,
|
||||
text_select=True,
|
||||
)
|
||||
webview.start(private_mode=False)
|
||||
|
||||
|
||||
@@ -136,12 +136,12 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp32",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"fp16",
|
||||
"fp32",
|
||||
"int4",
|
||||
"int8",
|
||||
"fp16",
|
||||
"fp32",
|
||||
],
|
||||
visible=True,
|
||||
)
|
||||
|
||||
@@ -11,21 +11,35 @@ def has_csv(image_filename: str) -> bool:
|
||||
return os.path.exists(csv_path(image_filename))
|
||||
|
||||
|
||||
def parse_csv(image_filename: str):
|
||||
# We use a reader instead of a DictReader here for images_details.csv files due to the lack of
|
||||
# headers, and then match up the return list for each row with our guess at which column format
|
||||
# the file is using.
|
||||
|
||||
def matching_filename(image_filename: str, row):
|
||||
# we assume the final column of the csv has the original filename with full path and match that
|
||||
# against the image_filename. We then exclude the filename from the output, hence the -1's.
|
||||
# against the image_filename if we are given a list. Otherwise we assume a dict and and take
|
||||
# the value of the OUTPUT key
|
||||
return os.path.basename(image_filename) in (
|
||||
row[-1] if isinstance(row, list) else row["OUTPUT"]
|
||||
)
|
||||
|
||||
|
||||
def parse_csv(image_filename: str):
|
||||
csv_filename = csv_path(image_filename)
|
||||
|
||||
matches = [
|
||||
humanize(row)
|
||||
for row in csv.reader(open(csv_filename, "r", newline=""))
|
||||
if row
|
||||
and humanizable(row)
|
||||
and os.path.basename(image_filename) in row[-1]
|
||||
]
|
||||
with open(csv_filename, "r", newline="") as csv_file:
|
||||
# We use a reader or DictReader here for images_details.csv depending on whether we think it
|
||||
# has headers or not. Having headers means less guessing of the format.
|
||||
has_header = csv.Sniffer().has_header(csv_file.read(2048))
|
||||
csv_file.seek(0)
|
||||
|
||||
reader = (
|
||||
csv.DictReader(csv_file) if has_header else csv.reader(csv_file)
|
||||
)
|
||||
|
||||
matches = [
|
||||
# we rely on humanize and humanizable to work out the parsing of the individual .csv rows
|
||||
humanize(row)
|
||||
for row in reader
|
||||
if row
|
||||
and (has_header or humanizable(row))
|
||||
and matching_filename(image_filename, row)
|
||||
]
|
||||
|
||||
return matches[0] if matches else {}
|
||||
|
||||
@@ -50,7 +50,22 @@ PARAMS_FORMATS = {
|
||||
},
|
||||
}
|
||||
|
||||
PARAMS_FORMAT_LONGEST = PARAMS_FORMATS[max(PARAMS_FORMATS.keys())]
|
||||
PARAMS_FORMAT_CURRENT = {
|
||||
"VARIANT": "Model",
|
||||
"VAE": "VAE",
|
||||
"LORA": "LoRA",
|
||||
"SCHEDULER": "Sampler",
|
||||
"PROMPT": "Prompt",
|
||||
"NEG_PROMPT": "Negative prompt",
|
||||
"SEED": "Seed",
|
||||
"CFG_SCALE": "CFG scale",
|
||||
"PRECISION": "Precision",
|
||||
"STEPS": "Steps",
|
||||
"HEIGHT": "Height",
|
||||
"WIDTH": "Width",
|
||||
"MAX_LENGTH": "Max Length",
|
||||
"OUTPUT": "Filename",
|
||||
}
|
||||
|
||||
|
||||
def compact(metadata: dict) -> dict:
|
||||
@@ -97,19 +112,20 @@ def humanize(metadata: dict | list[str], includes_filename=True) -> dict:
|
||||
)
|
||||
|
||||
# For dictionaries we try to use the matching length parameter format if
|
||||
# available, otherwise we use the longest. Then we swap keys in the
|
||||
# metadata that match keys in the format for the friendlier name that we
|
||||
# have set in the format value
|
||||
# available, otherwise we just use the current format which is assumed to
|
||||
# have everything currently known about. Then we swap keys in the metadata
|
||||
# that match keys in the format for the friendlier name that we have set
|
||||
# in the format value
|
||||
if isinstance(metadata, dict):
|
||||
if humanizable(metadata, includes_filename):
|
||||
format = PARAMS_FORMATS[lookup_key]
|
||||
else:
|
||||
format = PARAMS_FORMAT_LONGEST
|
||||
format = PARAMS_FORMAT_CURRENT
|
||||
|
||||
return {
|
||||
format[key]: value
|
||||
for (key, value) in metadata.items()
|
||||
if key in format.keys()
|
||||
format[key]: metadata[key]
|
||||
for key in format.keys()
|
||||
if key in metadata.keys() and metadata[key]
|
||||
}
|
||||
|
||||
raise TypeError("Can only humanize parameter lists or dictionaries")
|
||||
|
||||
154
shark/dynamo_backend/utils.py
Normal file
154
shark/dynamo_backend/utils.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.func import functionalize
|
||||
import io
|
||||
import torch_mlir
|
||||
|
||||
|
||||
# TODO: Control decompositions.
|
||||
def default_decompositions():
|
||||
return get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
torch.ops.aten.native_layer_norm,
|
||||
torch.ops.aten.masked_fill.Tensor,
|
||||
torch.ops.aten.masked_fill.Scalar,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
class SharkBackend:
|
||||
def __init__(
|
||||
self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
|
||||
):
|
||||
self.fx_g = fx_g
|
||||
self.inputs = inputs
|
||||
self.shark_module = None
|
||||
self.device: str = options.get("device", "cpu")
|
||||
self.was_unwrapped: bool = False
|
||||
self.none_indices: list = []
|
||||
self._modify_fx_g()
|
||||
self.compile()
|
||||
|
||||
def _modify_fx_g(self):
|
||||
self.none_indices = _remove_nones(self.fx_g)
|
||||
self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)
|
||||
|
||||
def compile(self):
|
||||
gm = make_fx(
|
||||
functionalize(self.fx_g),
|
||||
decomposition_table=default_decompositions(),
|
||||
)(*self.inputs)
|
||||
gm.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
gm.recompile()
|
||||
strip_overloads(gm)
|
||||
ts_g = torch.jit.script(gm)
|
||||
mlir_module = torch_mlir.compile(
|
||||
ts_g, self.inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=self.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args=[])
|
||||
self.shark_module = shark_module
|
||||
|
||||
def __call__(self, *inputs):
|
||||
np_inputs = [x.contiguous().detach().cpu().numpy() for x in inputs]
|
||||
np_outs = self.shark_module("forward", np_inputs)
|
||||
if self.was_unwrapped:
|
||||
np_outs = [
|
||||
np_outs,
|
||||
]
|
||||
|
||||
if not isinstance(np_outs, list):
|
||||
res = torch.from_numpy(np_outs)
|
||||
return res
|
||||
|
||||
result = [torch.from_numpy(x) for x in np_outs]
|
||||
for r_in in self.none_indices:
|
||||
result.insert(r_in, None)
|
||||
result = tuple(result)
|
||||
return result
|
||||
@@ -14,6 +14,7 @@
|
||||
import iree.runtime as ireert
|
||||
import iree.compiler as ireec
|
||||
from shark.iree_utils._common import iree_device_map, iree_target_map
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_rt_args
|
||||
from shark.iree_utils.benchmark_utils import *
|
||||
from shark.parser import shark_args
|
||||
import numpy as np
|
||||
@@ -352,6 +353,12 @@ def load_vmfb_using_mmap(
|
||||
config = ireert.Config(device=haldevice)
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
if "task" in device:
|
||||
print(
|
||||
f"[DEBUG] setting iree runtime flags for cpu:\n{' '.join(get_iree_cpu_rt_args())}"
|
||||
)
|
||||
for flag in get_iree_cpu_rt_args():
|
||||
ireert.flags.parse_flags(flag)
|
||||
# Now load vmfb.
|
||||
# Two scenarios we have here :-
|
||||
# 1. We either have the vmfb already saved and therefore pass the path of it.
|
||||
@@ -359,7 +366,6 @@ def load_vmfb_using_mmap(
|
||||
# OR 2. We are compiling on the fly, therefore we have the flatbuffer blob to play with.
|
||||
# (This would arise if we're invoking `compile` from a SharkInference obj)
|
||||
temp_file_to_unlink = None
|
||||
|
||||
if isinstance(flatbuffer_blob_or_path, Path):
|
||||
flatbuffer_blob_or_path = flatbuffer_blob_or_path.__str__()
|
||||
if (
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import subprocess
|
||||
import platform
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
def get_cpu_count():
|
||||
@@ -44,4 +45,18 @@ def get_iree_cpu_args():
|
||||
error_message = f"OS Type f{os_name} not supported and triple can't be determined, open issue to dSHARK team please :)"
|
||||
raise Exception(error_message)
|
||||
print(f"Target triple found:{target_triple}")
|
||||
return [f"--iree-llvmcpu-target-triple={target_triple}"]
|
||||
return [
|
||||
f"--iree-llvmcpu-target-triple={target_triple}",
|
||||
]
|
||||
|
||||
|
||||
# Get iree runtime flags for cpu
|
||||
def get_iree_cpu_rt_args():
|
||||
default = get_cpu_count()
|
||||
default = default if default <= 8 else default - 2
|
||||
cpu_count = (
|
||||
default
|
||||
if shark_args.task_topology_max_group_count is None
|
||||
else shark_args.task_topology_max_group_count
|
||||
)
|
||||
return [f"--task_topology_max_group_count={cpu_count}"]
|
||||
|
||||
@@ -119,5 +119,11 @@ parser.add_argument(
|
||||
"to augment the base device allocator",
|
||||
choices=["debug", "caching"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_topology_max_group_count",
|
||||
type=str,
|
||||
default=None,
|
||||
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
1. Install torchdynamo
|
||||
- `git clone https://github.com/pytorch/torchdynamo.git`
|
||||
- `cd torchdynamo`
|
||||
- `python -m pip install -r requirements.txt`
|
||||
- `python setup.py develop`
|
||||
|
||||
2. Install functorch
|
||||
- `python -m pip install -v "git+https://github.com/pytorch/pytorch.git@$(python -c "import torch.version; print(torch.version.git_version)")#subdirectory=functorch"`
|
||||
|
||||
3. Run examples.
|
||||
- `python shark/examples/shark_dynamo/basic_examples.py`
|
||||
@@ -1,163 +0,0 @@
|
||||
import functools
|
||||
import time
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
|
||||
import torch_mlir
|
||||
|
||||
|
||||
# TODO: Control decompositions.
|
||||
def default_decompositions():
|
||||
return get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def timeit(*, append_time_to: Optional[List] = None):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time_ns()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time_ns()
|
||||
|
||||
if append_time_to is not None:
|
||||
append_time_to.append(end_time - start_time)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
def make_shark_compiler(use_tracing: bool, device: str, verbose=False):
|
||||
def compiler(
|
||||
fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor],
|
||||
):
|
||||
"""Compile GraphModule using torch-mlir + SHARK."""
|
||||
if verbose:
|
||||
print("Compiling graph...")
|
||||
|
||||
if _returns_nothing(fx_graph):
|
||||
return fx_graph
|
||||
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
|
||||
fx_graph = make_fx(
|
||||
fx_graph, decomposition_table=default_decompositions()
|
||||
)(*example_inputs)
|
||||
strip_overloads(fx_graph)
|
||||
|
||||
if verbose:
|
||||
print("torch.fx graph:")
|
||||
print(fx_graph.graph)
|
||||
|
||||
ts_compiler = torch.jit.trace if use_tracing else torch.jit.script
|
||||
ts_graph = ts_compiler(fx_graph, example_inputs)
|
||||
|
||||
if verbose:
|
||||
torch_mlir_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
example_inputs,
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
)
|
||||
print("\n\ntorch-mlir backend contract graph:")
|
||||
print(torch_mlir_module)
|
||||
|
||||
linalg_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
example_inputs,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
)
|
||||
import io
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
linalg_module.operation.write_bytecode(bytecode_stream)
|
||||
mlir_module = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module, mlir_dialect="linalg", device=device
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
def forward(*inputs):
|
||||
result = shark_module("forward", inputs)
|
||||
result = tuple() if result is None else result
|
||||
return (result,) if was_unwrapped else result
|
||||
|
||||
return forward
|
||||
|
||||
return compiler
|
||||
|
||||
|
||||
def check_results(compiled_results, eager_results):
|
||||
for compiled_result, eager_result in zip(compiled_results, eager_results):
|
||||
if not torch.allclose(
|
||||
compiled_result.to("cpu"), eager_result.to("cpu"), atol=1e-5
|
||||
):
|
||||
print("Compiled result does not match eager result")
|
||||
return
|
||||
print("Compiled result matches eager result!")
|
||||
|
||||
|
||||
def print_time_stats(times):
|
||||
times_tensor = torch.tensor(times)
|
||||
|
||||
def quantile_ms(q):
|
||||
return torch.quantile(times_tensor.to(float), q).item() / 1e6
|
||||
|
||||
print(f"Median: {quantile_ms(0.5)} ms")
|
||||
print(f"10%ile: {quantile_ms(0.1)} ms")
|
||||
print(f"90%ile: {quantile_ms(0.9)} ms")
|
||||
print(f"Total: {torch.sum(times_tensor) / 1e6} ms")
|
||||
print()
|
||||
Reference in New Issue
Block a user