Compare commits

...

2 Commits

Author SHA1 Message Date
Aarushi
7cfea2a480 wip 2024-10-08 12:36:45 +01:00
Aarushi
0fbe0a0fa1 enable ollama only when local 2024-10-08 10:48:59 +01:00

View File

@@ -12,6 +12,9 @@ from groq import Groq
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import BlockSecret, SchemaField, SecretField from backend.data.model import BlockSecret, SchemaField, SecretField
from backend.util import json from backend.util import json
from backend.util.settings import AppEnvironment, Settings
settings = Settings()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -50,14 +53,31 @@ class LlmModel(str, Enum):
LLAMA3_1_405B = "llama-3.1-405b-reasoning" LLAMA3_1_405B = "llama-3.1-405b-reasoning"
LLAMA3_1_70B = "llama-3.1-70b-versatile" LLAMA3_1_70B = "llama-3.1-70b-versatile"
LLAMA3_1_8B = "llama-3.1-8b-instant" LLAMA3_1_8B = "llama-3.1-8b-instant"
@classmethod
def _missing_(cls, value):
    """Resolve Ollama model names when a value lookup finds no member.

    ``Enum`` calls this hook when ``value`` does not match any member.
    The Ollama names are only resolved when the app runs in the LOCAL
    environment; in every other case ``None`` is returned, which lets
    ``Enum`` raise its usual ``ValueError``.

    NOTE(review): "llama3" and "llama3.1:405b" also appear as regular
    member values of this enum, so a normal value lookup would presumably
    resolve them before this hook is reached -- confirm whether this
    fallback is still needed.
    """
    running_locally = settings.config.app_env == AppEnvironment.LOCAL
    if not running_locally:
        return None

    if value == "llama3":
        return cls.OLLAMA_LLAMA3_8B
    if value == "llama3.1:405b":
        return cls.OLLAMA_LLAMA3_405B
    return None
@property
def metadata(self) -> "ModelMetadata":
    """Return the ``MODEL_METADATA`` entry registered for this model.

    Raises:
        KeyError: if this model has no entry in ``MODEL_METADATA``.
    """
    model_info = MODEL_METADATA[self]
    return model_info
@classmethod
def all_models(cls) -> List["LlmModel"]:
    """Return the models available in the current app environment.

    Iterating the enum already yields every declared member, including
    the Ollama ones, so they must not be appended a second time. In the
    LOCAL environment every member is returned exactly once; in any other
    environment the Ollama members are filtered out, because their
    ``MODEL_METADATA`` entries are only registered when running locally.

    Returns:
        The enabled models, in declaration order and without duplicates.
    """
    if settings.config.app_env == AppEnvironment.LOCAL:
        return list(cls)

    # Ollama is only reachable on a local deployment; hide it elsewhere.
    ollama_only = (cls.OLLAMA_LLAMA3_8B, cls.OLLAMA_LLAMA3_405B)
    return [model for model in cls if model not in ollama_only]
# Ollama models # Ollama models
OLLAMA_LLAMA3_8B = "llama3" OLLAMA_LLAMA3_8B = "llama3"
OLLAMA_LLAMA3_405B = "llama3.1:405b" OLLAMA_LLAMA3_405B = "llama3.1:405b"
@property
def metadata(self) -> ModelMetadata:
    """Look up this model's entry in ``MODEL_METADATA``.

    Raises:
        KeyError: if no entry exists for this model.
    """
    entry = MODEL_METADATA[self]
    return entry
MODEL_METADATA = { MODEL_METADATA = {
LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60), LlmModel.O1_PREVIEW: ModelMetadata("openai", 32000, cost_factor=60),
@@ -77,11 +97,17 @@ MODEL_METADATA = {
# Limited to 16k during preview # Limited to 16k during preview
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=15), LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=15),
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=13), LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=13),
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
} }
for model in LlmModel: if settings.config.app_env == AppEnvironment.LOCAL:
MODEL_METADATA.update(
{
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
}
)
for model in LlmModel.all_models():
if model not in MODEL_METADATA: if model not in MODEL_METADATA:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}") raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")