mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-02-19 11:56:43 -05:00
Single endpoint in apps/language/models/scripts/vicuna.py; removed main functions from pipelines; replaced divergent utils compile with shark_importer; added support for different precisions.
26 lines
734 B
Python
26 lines
734 B
Python
import torch
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch._decomp import get_decompositions
|
|
from typing import List
|
|
from pathlib import Path
|
|
|
|
|
|
def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
    """Load a precompiled ``.vmfb`` module into a ``SharkInference`` instance.

    Args:
        vmfb_path: Location of the compiled module, given as a ``str`` or a
            ``pathlib.Path``.
        device: Forwarded unchanged as ``SharkInference(device=...)``.
        mlir_dialect: Forwarded unchanged as
            ``SharkInference(mlir_dialect=...)``.

    Returns:
        A ``SharkInference`` module with the vmfb loaded, or ``None`` if
        ``vmfb_path`` does not exist.
    """
    # Path() accepts both str and Path, so no isinstance check is needed.
    vmfb_path = Path(vmfb_path)

    # Check for the file *before* importing SharkInference: the "not found"
    # case then stays cheap and does not require the shark package at all.
    if not vmfb_path.exists():
        return None

    # Deferred import: only needed once there is actually something to load.
    from shark.shark_inference import SharkInference

    print("Loading vmfb from: ", vmfb_path)
    shark_module = SharkInference(
        None, device=device, mlir_dialect=mlir_dialect
    )
    shark_module.load_module(vmfb_path)
    print("Successfully loaded vmfb")
    return shark_module
|