Compare commits

..

4 Commits

Author SHA1 Message Date
Ean Garvey
326327a799 Collect pipeline submodules for diffusers ckpt preprocessing. (#1859) 2023-10-03 00:29:28 -04:00
Ean Garvey
785b65c7b0 Add flag for specifying device-local caching allocator heap key. (#1856) 2023-10-03 00:28:39 -04:00
Sungsoon Cho
0d16c81687 Remove unused import. (#1857) 2023-10-02 11:36:08 -05:00
Vivek Khandelwal
8dd7850c69 Add Falcon-GPTQ support 2023-10-02 16:39:57 +05:30
6 changed files with 104 additions and 8 deletions

View File

@@ -118,6 +118,7 @@ class Falcon(SharkLLMBase):
"torch_dtype": torch.float,
"trust_remote_code": True,
"token": self.hf_auth_token,
"device_map": "cpu" if args.device == "cpu" else "cuda:0",
}
falcon_model = AutoModelForCausalLM.from_pretrained(
self.hf_model_path, **kwargs
@@ -198,6 +199,7 @@ class Falcon(SharkLLMBase):
is_f16=self.precision == "fp16",
f16_input_mask=[False, False],
mlir_type="torchscript",
is_gptq=self.precision == "int4",
)
del model
print(f"[DEBUG] generating torch mlir")
@@ -488,12 +490,22 @@ if __name__ == "__main__":
else Path(args.falcon_vmfb_path)
)
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "tiiuae/falcon-180B-chat"
if args.precision == "int4":
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ"
else:
hf_model_path_value = (
"TheBloke/falcon-"
+ args.falcon_variant_to_use
+ "-instruct-GPTQ"
)
else:
hf_model_path_value = (
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)
if args.falcon_variant_to_use == "180b":
hf_model_path_value = "tiiuae/falcon-180B-chat"
else:
hf_model_path_value = (
"tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct"
)
falcon = Falcon(
model_name="falcon_" + args.falcon_variant_to_use,
@@ -524,7 +536,11 @@ if __name__ == "__main__":
prompt = input("Please enter the prompt text: ")
print("\nPrompt Text: ", prompt)
res_str = falcon.generate(prompt)
prompt_template = f"""A helpful assistant who helps the user with any questions asked.
User: {prompt}
Assistant:"""
res_str = falcon.generate(prompt_template)
torch.cuda.empty_cache()
gc.collect()
print(

View File

@@ -74,6 +74,9 @@ datas += [
# hidden imports for pyinstaller
hiddenimports = ["shark", "shark.shark_inference", "apps"]
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
hiddenimports += [
x for x in collect_submodules("diffusers") if "tests" not in x
]
blacklist = ["tests", "convert"]
hiddenimports += [
x

View File

@@ -458,6 +458,14 @@ p.add_argument(
help="Specify your own huggingface authentication tokens for models like Llama2.",
)
p.add_argument(
"--device_allocator_heap_key",
type=str,
default="",
help="Specify heap key for device caching allocator."
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
)
##############################################################################
# IREE - Vulkan supported flags
##############################################################################

View File

@@ -184,12 +184,18 @@ def compile_through_fx(
def set_iree_runtime_flags():
# TODO: This function should be device-agnostic and piped properly
# to general runtime driver init.
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
if args.enable_rgp:
vulkan_runtime_flags += [
f"--enable_rgp=true",
f"--vulkan_debug_utils=true",
]
if args.device_allocator_heap_key:
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={args.device_allocator_heap_key}",
]
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)

View File

@@ -451,6 +451,65 @@ def transform_fx(fx_g, quantized=False):
fx_g.graph.lint()
def gptq_transforms(fx_g):
import torch
for node in fx_g.graph.nodes:
if node.op == "call_function":
if node.target in [
torch.ops.aten.arange,
torch.ops.aten.empty,
torch.ops.aten.ones,
torch.ops.aten._to_copy,
]:
if node.kwargs.get("device") == torch.device(device="cuda:0"):
updated_kwargs = node.kwargs.copy()
updated_kwargs["device"] = torch.device(device="cpu")
node.kwargs = updated_kwargs
if node.target in [
torch.ops.aten._to_copy,
]:
if node.kwargs.get("dtype") == torch.bfloat16:
updated_kwargs = node.kwargs.copy()
updated_kwargs["dtype"] = torch.float16
node.kwargs = updated_kwargs
# Inputs of aten.native_layer_norm should be upcasted to fp32.
if node.target in [torch.ops.aten.native_layer_norm]:
with fx_g.graph.inserting_before(node):
new_node_arg0 = fx_g.graph.call_function(
torch.ops.prims.convert_element_type,
args=(node.args[0], torch.float32),
kwargs={},
)
node.args = (
new_node_arg0,
node.args[1],
node.args[2],
node.args[3],
node.args[4],
)
# Downcasting the result of native_layer_norm back to fp16.
if node.name.startswith("getitem"):
with fx_g.graph.inserting_before(node):
if node.args[0].target in [
torch.ops.aten.native_layer_norm
]:
new_node = fx_g.graph.call_function(
torch.ops.aten._to_copy,
args=(node,),
kwargs={"dtype": torch.float32},
)
node.append(new_node)
node.replace_all_uses_with(new_node)
new_node.args = (node,)
new_node.kwargs = {"dtype": torch.float32}
fx_g.graph.lint()
# Doesn't replace the None type.
def change_fx_graph_return_to_tuple(fx_g):
for node in fx_g.graph.nodes:
@@ -504,6 +563,7 @@ def import_with_fx(
is_dynamic=False,
tracing_required=False,
precision="fp32",
is_gptq=False,
):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
@@ -584,7 +644,7 @@ def import_with_fx(
torch.ops.aten.index_add,
torch.ops.aten.index_add_,
]
if precision in ["int4", "int8"]:
if precision in ["int4", "int8"] and not is_gptq:
from brevitas_examples.llm.llm_quant.export import (
block_quant_layer_level_manager,
)
@@ -653,6 +713,10 @@ def import_with_fx(
add_upcast(fx_g)
fx_g.recompile()
if is_gptq:
gptq_transforms(fx_g)
fx_g.recompile()
if mlir_type == "fx":
return fx_g

View File

@@ -18,7 +18,6 @@ import collections
import json
import os
import psutil
import resource
import time
from typing import Tuple