Ollama - Remote hosts (#8234)

### Background

Currently, AutoGPT only supports Ollama servers running locally. In practice, the Ollama server often runs on a separate, better-suited machine, such as a Jetson board. This PR adds an "Ollama host" input to all LLM blocks, allowing users to select which Ollama host each block should use.
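
Under the hood this simply exposes the `host` parameter of the `ollama` Python client. A minimal sketch of what the new input enables (the remote address and model name below are illustrative, not taken from this PR):

```python
import ollama

# Default behaviour: talk to an Ollama server on this machine.
local = ollama.Client(host="localhost:11434")

# With a configurable host, a block can instead target a remote server,
# e.g. an Ollama instance running on a Jetson board.
# "jetson.local:11434" is an illustrative address; use your own host:port.
remote = ollama.Client(host="jetson.local:11434")

result = remote.generate(model="llama3", prompt="ping", stream=False)
print(result["response"])
```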

### Changes 🏗️

- Changes contained within `blocks/llm.py`:
    - Added an Ollama host input to all LLM blocks.
- Fixed incorrect parsing of the prompt when passing it to Ollama in the StructuredResponse block:
    - Used `ollama.Client` instances to accomplish this (see the sketch after this list).
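
Roughly, the Ollama code path in the StructuredResponse block now follows the pattern below. This is a simplified sketch of the diff further down, not the exact block code; the function name is illustrative:

```python
import ollama


def ollama_generate(
    prompt: list[dict], model: str, host: str = "localhost:11434"
) -> str:
    """Send a chat-style prompt to an Ollama host and return the response text.

    Mirrors the fix in blocks/llm.py: system and user messages are split out
    of the chat prompt, and a per-request ollama.Client is bound to the
    selected host instead of the module-level default.
    """
    client = ollama.Client(host=host)
    sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
    usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
    response = client.generate(
        model=model,
        prompt=f"{sys_messages}\n\n{usr_messages}",
        stream=False,
    )
    return response["response"]
```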


### Testing 🔍

Tested all LLM blocks against remote Ollama hosts as well as with the default localhost value. A quick reachability check for a remote host is sketched below.
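
For anyone reproducing this, a simple way to sanity-check that a remote host is reachable before pointing a block at it (the addresses are illustrative):

```python
import ollama


def ollama_host_reachable(host: str = "localhost:11434") -> bool:
    """Return True if an Ollama server responds at the given host."""
    try:
        ollama.Client(host=host).list()  # lists the models the server exposes
        return True
    except Exception:
        return False


for host in ("localhost:11434", "192.168.1.50:11434"):
    print(host, "reachable:", ollama_host_reachable(host))
```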


### Related issues
https://github.com/Significant-Gravitas/AutoGPT/issues/8225

---------

Co-authored-by: Fried-Squid <Fried-Squid>
Co-authored-by: Toran Bruce Richards <toran.richards@gmail.com>
Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
Co-authored-by: Aarushi <50577581+aarushik93@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
Co-authored-by: Nicholas Tindle <nicktindle@outlook.com>
Authored by Ace on 2024-12-13 00:02:49 +00:00, committed by GitHub
parent de3c096e23, commit 94a312a279
2 changed files with 35 additions and 1 deletion

First changed file: blocks/llm.py

@@ -111,6 +111,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
     # Ollama models
     OLLAMA_LLAMA3_8B = "llama3"
     OLLAMA_LLAMA3_405B = "llama3.1:405b"
+    OLLAMA_DOLPHIN = "dolphin-mistral:latest"
     # OpenRouter models
     GEMINI_FLASH_1_5_8B = "google/gemini-flash-1.5"
     GROK_BETA = "x-ai/grok-beta"
@@ -164,6 +165,7 @@ MODEL_METADATA = {
     LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
     LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
     LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
+    LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768),
     LlmModel.GEMINI_FLASH_1_5_8B: ModelMetadata("open_router", 8192),
     LlmModel.GROK_BETA: ModelMetadata("open_router", 8192),
     LlmModel.MISTRAL_NEMO: ModelMetadata("open_router", 4000),
@@ -240,6 +242,12 @@ class AIStructuredResponseGeneratorBlock(Block):
             description="The maximum number of tokens to generate in the chat completion.",
         )
+        ollama_host: str = SchemaField(
+            advanced=True,
+            default="localhost:11434",
+            description="Ollama host for local models",
+        )
 
     class Output(BlockSchema):
         response: dict[str, Any] = SchemaField(
             description="The response object generated by the language model."
@@ -285,6 +293,7 @@ class AIStructuredResponseGeneratorBlock(Block):
         prompt: list[dict],
         json_format: bool,
         max_tokens: int | None = None,
+        ollama_host: str = "localhost:11434",
     ) -> tuple[str, int, int]:
         """
         Args:
@@ -293,6 +302,7 @@ class AIStructuredResponseGeneratorBlock(Block):
             prompt: The prompt to send to the LLM.
             json_format: Whether the response should be in JSON format.
             max_tokens: The maximum number of tokens to generate in the chat completion.
+            ollama_host: The host for ollama to use
 
         Returns:
             The response from the LLM.
@@ -382,9 +392,10 @@ class AIStructuredResponseGeneratorBlock(Block):
                 response.usage.completion_tokens if response.usage else 0,
             )
         elif provider == "ollama":
+            client = ollama.Client(host=ollama_host)
             sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
             usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
-            response = ollama.generate(
+            response = client.generate(
                 model=llm_model.value,
                 prompt=f"{sys_messages}\n\n{usr_messages}",
                 stream=False,
@@ -484,6 +495,7 @@ class AIStructuredResponseGeneratorBlock(Block):
                     llm_model=llm_model,
                     prompt=prompt,
                     json_format=bool(input_data.expected_format),
+                    ollama_host=input_data.ollama_host,
                     max_tokens=input_data.max_tokens,
                 )
                 self.merge_stats(
@@ -566,6 +578,11 @@ class AITextGeneratorBlock(Block):
         prompt_values: dict[str, str] = SchemaField(
             advanced=False, default={}, description="Values used to fill in the prompt."
         )
+        ollama_host: str = SchemaField(
+            advanced=True,
+            default="localhost:11434",
+            description="Ollama host for local models",
+        )
         max_tokens: int | None = SchemaField(
             advanced=True,
             default=None,
@@ -656,6 +673,11 @@ class AITextSummarizerBlock(Block):
             description="The number of overlapping tokens between chunks to maintain context.",
             ge=0,
         )
+        ollama_host: str = SchemaField(
+            advanced=True,
+            default="localhost:11434",
+            description="Ollama host for local models",
+        )
 
     class Output(BlockSchema):
         summary: str = SchemaField(description="The final summary of the text.")
@@ -794,6 +816,11 @@ class AIConversationBlock(Block):
             default=None,
             description="The maximum number of tokens to generate in the chat completion.",
         )
+        ollama_host: str = SchemaField(
+            advanced=True,
+            default="localhost:11434",
+            description="Ollama host for local models",
+        )
 
     class Output(BlockSchema):
         response: str = SchemaField(
@@ -891,6 +918,11 @@ class AIListGeneratorBlock(Block):
             default=None,
             description="The maximum number of tokens to generate in the chat completion.",
         )
+        ollama_host: str = SchemaField(
+            advanced=True,
+            default="localhost:11434",
+            description="Ollama host for local models",
+        )
 
     class Output(BlockSchema):
         generated_list: List[str] = SchemaField(description="The generated list.")
@@ -1042,6 +1074,7 @@ class AIListGeneratorBlock(Block):
                 credentials=input_data.credentials,
                 model=input_data.model,
                 expected_format={},  # Do not use structured response
+                ollama_host=input_data.ollama_host,
             ),
             credentials=credentials,
         )

Second changed file (block cost table):

@@ -53,6 +53,7 @@ MODEL_COST: dict[LlmModel, int] = {
     LlmModel.LLAMA3_1_8B: 1,
     LlmModel.OLLAMA_LLAMA3_8B: 1,
     LlmModel.OLLAMA_LLAMA3_405B: 1,
+    LlmModel.OLLAMA_DOLPHIN: 1,
     LlmModel.GEMINI_FLASH_1_5_8B: 1,
     LlmModel.GROK_BETA: 5,
     LlmModel.MISTRAL_NEMO: 1,