Compare commits

...

2 Commits

Author    SHA1         Message                          Date
Aarushi   7cfea2a480   wip                              2024-10-08 12:36:45 +01:00
Aarushi   0fbe0a0fa1   enable ollama only when local    2024-10-08 10:48:59 +01:00

View File

@@ -12,6 +12,9 @@ from groq import Groq
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import BlockSecret, SchemaField, SecretField
from backend.util import json
from backend.util.settings import AppEnvironment, Settings
settings = Settings()
logger = logging.getLogger(__name__)
@@ -50,14 +53,31 @@ class LlmModel(str, Enum):
LLAMA3_1_405B = "llama-3.1-405b-reasoning"
LLAMA3_1_70B = "llama-3.1-70b-versatile"
LLAMA3_1_8B = "llama-3.1-8b-instant"
@classmethod
def _missing_(cls, value):
if settings.config.app_env == AppEnvironment.LOCAL:
if value == "llama3":
return cls.OLLAMA_LLAMA3_8B
elif value == "llama3.1:405b":
return cls.OLLAMA_LLAMA3_405B
return None
@property
def metadata(self) -> "ModelMetadata":
return MODEL_METADATA[self]
@classmethod
def all_models(cls) -> List["LlmModel"]:
models = list(cls)
if settings.config.app_env == AppEnvironment.LOCAL:
models.extend([cls.OLLAMA_LLAMA3_8B, cls.OLLAMA_LLAMA3_405B])
return models
# Ollama models
OLLAMA_LLAMA3_8B = "llama3"
OLLAMA_LLAMA3_405B = "llama3.1:405b"
@property
def metadata(self) -> ModelMetadata:
return MODEL_METADATA[self]
MODEL_METADATA = {
LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60),
@@ -77,11 +97,17 @@ MODEL_METADATA = {
# Limited to 16k during preview
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=15),
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=13),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
}
for model in LlmModel:
if settings.config.app_env == AppEnvironment.LOCAL:
MODEL_METADATA.update(
{
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
}
)
for model in LlmModel.all_models():
if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")