Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git, synced 2026-02-19 11:56:43 -05:00.
Latest commit: Shark Studio SDXL support, HIP driver support, simpler device info, and small fixes. Includes: fixups to the LLM API/UI and ignoring user config files; small fixes unifying pipelines; requirements.txt update for iree-turbine (#2130); Llama2-on-CPU fix (#2133); filesystem cleanup and custom-model fixes (#2127); formatting fixes; IREE pin removal (fixes exe issue, #2126); updated find-links for IREE packages (#2136); SD pipelines abstracted out of the Studio webui (WIP); switch from a torch pin to a minimum version with a fixed index URL; device-parsing, Linux-setup, and custom-weights fixes.
Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com>, gpetters-amd <159576198+gpetters-amd@users.noreply.github.com>, gpetters94 <gpetters@protonmail.com>
File: 59 lines, 1.8 KiB, Python.
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import logging
|
|
import unittest
|
|
import json
|
|
import gc
|
|
from apps.shark_studio.api.llm import LanguageModel, llm_chat_api
|
|
from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file
|
|
from apps.shark_studio.web.utils.file_utils import get_resource_path
|
|
|
|
# class SDAPITest(unittest.TestCase):
#     def testSDSimple(self):
#         from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
#         import apps.shark_studio.web.utils.globals as global_obj
#
#         global_obj._init()
#
#         sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
#         sd_kwargs = json.loads(sd_json)
#         for arg in vars(cmd_opts):
#             if arg in sd_kwargs:
#                 sd_kwargs[arg] = getattr(cmd_opts, arg)
#         for i in shark_sd_fn_dict_input(sd_kwargs):
#             print(i)
class LLMAPITest(unittest.TestCase):
    """Smoke tests for the Studio language-model (LLM) API."""

    def test01_LLMSmall(self):
        """Stream one chat response from a small model and check its text.

        Uses the TinyPixel/small-llama2 checkpoint on CPU at fp32 so the
        test stays small and deterministic. The first streamed chunk is
        skipped; only the second chunk is compared against the expected
        label.
        """
        lm = LanguageModel(
            "TinyPixel/small-llama2",
            hf_auth_token=None,
            device="cpu",
            precision="fp32",
            quantization="None",
            streaming_llm=True,
        )
        count = 0
        label = "Turkishoure Turkish"
        for msg, _ in lm.chat("hi, what are you?"):
            # Skip the first token output.
            if count == 0:
                count += 1
                continue
            # self.assertEqual instead of a bare assert: it is not
            # stripped under `python -O` and reports failures through
            # the unittest machinery.
            self.assertEqual(
                msg.strip(" "),
                label,
                f"LLM API failed to return correct response, expected '{label}', received {msg}",
            )
            break
        # Release the model eagerly so subsequent tests start from a
        # clean slate.
        del lm
        gc.collect()
if __name__ == "__main__":
    # Enable verbose logging before handing control to the unittest
    # runner, so test-time diagnostics are visible.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()