mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
* [MiniGPT4] Add MiniGPT4 to SHARK -- This is the first installment of MiniGPT4 in SHARK. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com> * Add int8 support for MiniGPT4 -- This commit adds int8 support for MiniGPT4. Signed-off-by: Abhishek Varma <abhishek@nod-lab.com> * Update .spec for MiniGPT4's config files * black format MiniGPT4 --------- Signed-off-by: Abhishek Varma <abhishek@nod-labs.com> Signed-off-by: Abhishek Varma <abhishek@nod-lab.com>
41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
import torch
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch._decomp import get_decompositions
|
|
from typing import List
|
|
from pathlib import Path
|
|
from shark.shark_downloader import download_public_file
|
|
|
|
|
|
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
    """Load a precompiled ``.vmfb`` file into a ``SharkInference`` module.

    Args:
        vmfb_path: Location of the ``.vmfb`` file, as a ``Path`` or ``str``.
        device: Device string forwarded to ``SharkInference``.
        mlir_dialect: MLIR dialect name forwarded to ``SharkInference``.

    Returns:
        ``None`` if ``vmfb_path`` does not exist; otherwise a
        ``SharkInference`` instance with the file loaded via ``load_module``.
    """
    if not isinstance(vmfb_path, Path):
        vmfb_path = Path(vmfb_path)

    if not vmfb_path.exists():
        return None

    # Imported lazily, and only once we know there is a file to load, so a
    # missing vmfb does not pay for (or depend on) importing the inference
    # stack.
    from shark.shark_inference import SharkInference

    print("Loading vmfb from: ", vmfb_path)
    print("Device from get_vmfb_from_path - ", device)
    shark_module = SharkInference(
        None, device=device, mlir_dialect=mlir_dialect
    )
    shark_module.load_module(vmfb_path)
    print("Successfully loaded vmfb")
    return shark_module
|
|
|
|
|
|
def get_vmfb_from_config(
    shark_container,
    model,
    precision,
    device,
    vmfb_path,
    padding=None,
    mlir_dialect="tm_tensor",
):
    """Download a prebuilt ``.vmfb`` from the ``shark_tank`` bucket and load it.

    The object fetched is
    ``gs://shark_tank/{shark_container}/{model}_{precision}_{device}[_{padding}].vmfb``.

    Args:
        shark_container: Folder within the ``shark_tank`` bucket.
        model: Model-name component of the file name.
        precision: Precision component of the file name.
        device: Device component of the file name; also forwarded to the
            loader.
        vmfb_path: Local destination (``Path`` or ``str``) for the download.
        padding: Optional suffix appended as ``_{padding}``. Falsy values
            (``None``, ``0``, ``""``) are omitted from the name.
        mlir_dialect: Dialect passed to ``get_vmfb_from_path``. Defaults to
            ``"tm_tensor"``.

    Returns:
        The result of ``get_vmfb_from_path`` for ``vmfb_path``.
    """
    # Accept str as well as Path, like get_vmfb_from_path does;
    # ``.absolute()`` below requires a Path object.
    if not isinstance(vmfb_path, Path):
        vmfb_path = Path(vmfb_path)

    vmfb_url = (
        f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
    )
    if padding:
        vmfb_url += f"_{padding}"
    vmfb_url += ".vmfb"

    download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
    return get_vmfb_from_path(vmfb_path, device, mlir_dialect)
|