mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Move clean_device_info to compile_utils
This commit is contained in:
@@ -6,6 +6,7 @@ from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from shark.iree_utils.compile_utils import clean_device_info
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import sys
|
||||
@@ -131,28 +132,6 @@ def get_default_config():
|
||||
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
|
||||
c.split_into_layers()
|
||||
|
||||
|
||||
def clean_device_info(raw_device):
|
||||
# return appropriate device and device_id for consumption by LLM pipeline
|
||||
# Multiple devices only supported for vulkan and rocm (as of now).
|
||||
# default device must be selected for all others
|
||||
|
||||
device_id = None
|
||||
device = (
|
||||
raw_device
|
||||
if "=>" not in raw_device
|
||||
else raw_device.split("=>")[1].strip()
|
||||
)
|
||||
if "://" in device:
|
||||
device, device_id = device.split("://")
|
||||
device_id = int(device_id) # using device index in webui
|
||||
|
||||
if device not in ["rocm", "vulkan"]:
|
||||
device_id = None
|
||||
|
||||
return device, device_id
|
||||
|
||||
|
||||
model_vmfb_key = ""
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user