mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
swap to cpu an remove hardcoded paths (#1448)
Co-authored-by: powderluv <powderluv@users.noreply.github.com>
This commit is contained in:
@@ -2,8 +2,6 @@ import sys
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_compiler")
|
||||
sys.path.insert(0, "D:\S\SB\I\python_packages\iree_runtime")
|
||||
import torch
|
||||
import torch_mlir
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
@@ -20,6 +18,9 @@ from tqdm import tqdm
|
||||
from torch_mlir import TensorPlaceholder
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
@@ -545,7 +546,7 @@ def compile_to_vmfb(inputs, layers, is_first=True):
|
||||
else:
|
||||
vmfb_path = Path(f"{idx}_1.vmfb")
|
||||
if idx < 25:
|
||||
device = "vulkan"
|
||||
device = "cpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
if vmfb_path.exists():
|
||||
|
||||
Reference in New Issue
Block a user