Merge remote-tracking branch 'origin/dev' into swiftyos/open-1920-marketplace-home-components

This commit is contained in:
SwiftyOS
2024-10-23 15:00:14 +02:00
46 changed files with 1005 additions and 270 deletions

View File

@@ -0,0 +1,182 @@
name: AutoGPT Platform - Build, Push, and Deploy Prod Environment
on:
release:
types: [published]
permissions:
contents: 'read'
id-token: 'write'
env:
PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
GKE_CLUSTER: prod-gke-cluster
GKE_ZONE: us-central1-a
NAMESPACE: prod-agpt
jobs:
migrate:
environment: production
name: Run migrations for AutoGPT Platform
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install prisma
- name: Run Backend Migrations
working-directory: ./autogpt_platform/backend
run: |
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
- name: Run Market Migrations
working-directory: ./autogpt_platform/market
run: |
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.MARKET_DATABASE_URL }}
build-push-deploy:
environment: production
name: Build, Push, and Deploy
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
with:
fetch-depth: 0
- id: 'auth'
uses: 'google-github-actions/auth@v1'
with:
workload_identity_provider: 'projects/638488734936/locations/global/workloadIdentityPools/prod-pool/providers/github'
service_account: 'prod-github-actions-sa@agpt-prod.iam.gserviceaccount.com'
token_format: 'access_token'
create_credentials_file: true
- name: 'Set up Cloud SDK'
uses: 'google-github-actions/setup-gcloud@v1'
- name: 'Configure Docker'
run: |
gcloud auth configure-docker us-east1-docker.pkg.dev
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Cache Docker layers
uses: actions/cache@v2
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
restore-keys: |
${{ runner.os }}-buildx-
- name: Check for changes
id: check_changes
run: |
git fetch origin master
BACKEND_CHANGED=$(git diff --name-only origin/master HEAD | grep "^autogpt_platform/backend/" && echo "true" || echo "false")
FRONTEND_CHANGED=$(git diff --name-only origin/master HEAD | grep "^autogpt_platform/frontend/" && echo "true" || echo "false")
MARKET_CHANGED=$(git diff --name-only origin/master HEAD | grep "^autogpt_platform/market/" && echo "true" || echo "false")
echo "backend_changed=$BACKEND_CHANGED" >> $GITHUB_OUTPUT
echo "frontend_changed=$FRONTEND_CHANGED" >> $GITHUB_OUTPUT
echo "market_changed=$MARKET_CHANGED" >> $GITHUB_OUTPUT
- name: Get GKE credentials
uses: 'google-github-actions/get-gke-credentials@v1'
with:
cluster_name: ${{ env.GKE_CLUSTER }}
location: ${{ env.GKE_ZONE }}
- name: Build and Push Backend
if: steps.check_changes.outputs.backend_changed == 'true'
uses: docker/build-push-action@v2
with:
context: .
file: ./autogpt_platform/backend/Dockerfile
push: true
tags: us-east1-docker.pkg.dev/agpt-prod/agpt-backend-prod/agpt-backend-prod:${{ github.sha }}
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Build and Push Frontend
if: steps.check_changes.outputs.frontend_changed == 'true'
uses: docker/build-push-action@v2
with:
context: .
file: ./autogpt_platform/frontend/Dockerfile
push: true
tags: us-east1-docker.pkg.dev/agpt-prod/agpt-frontend-prod/agpt-frontend-prod:${{ github.sha }}
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Build and Push Market
if: steps.check_changes.outputs.market_changed == 'true'
uses: docker/build-push-action@v2
with:
context: .
file: ./autogpt_platform/market/Dockerfile
push: true
tags: us-east1-docker.pkg.dev/agpt-prod/agpt-market-prod/agpt-market-prod:${{ github.sha }}
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Move cache
run: |
rm -rf /tmp/.buildx-cache
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
- name: Set up Helm
uses: azure/setup-helm@v1
with:
version: v3.4.0
- name: Deploy Backend
if: steps.check_changes.outputs.backend_changed == 'true'
run: |
helm upgrade autogpt-server ./autogpt-server \
--namespace ${{ env.NAMESPACE }} \
-f autogpt-server/values.yaml \
-f autogpt-server/values.prod.yaml \
--set image.tag=${{ github.sha }}
- name: Deploy Websocket
if: steps.check_changes.outputs.backend_changed == 'true'
run: |
helm upgrade autogpt-websocket-server ./autogpt-websocket-server \
--namespace ${{ env.NAMESPACE }} \
-f autogpt-websocket-server/values.yaml \
-f autogpt-websocket-server/values.prod.yaml \
--set image.tag=${{ github.sha }}
- name: Deploy Market
if: steps.check_changes.outputs.market_changed == 'true'
run: |
helm upgrade autogpt-market ./autogpt-market \
--namespace ${{ env.NAMESPACE }} \
-f autogpt-market/values.yaml \
-f autogpt-market/values.prod.yaml \
--set image.tag=${{ github.sha }}
- name: Deploy Frontend
if: steps.check_changes.outputs.frontend_changed == 'true'
run: |
helm upgrade autogpt-builder ./autogpt-builder \
--namespace ${{ env.NAMESPACE }} \
-f autogpt-builder/values.yaml \
-f autogpt-builder/values.prod.yaml \
--set image.tag=${{ github.sha }}

View File

@@ -0,0 +1,186 @@
name: AutoGPT Platform - Build, Push, and Deploy Dev Environment
on:
push:
branches: [ dev ]
paths:
- 'autogpt_platform/backend/**'
- 'autogpt_platform/frontend/**'
- 'autogpt_platform/market/**'
permissions:
contents: 'read'
id-token: 'write'
env:
PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
GKE_CLUSTER: dev-gke-cluster
GKE_ZONE: us-central1-a
NAMESPACE: dev-agpt
jobs:
migrate:
environment: develop
name: Run migrations for AutoGPT Platform
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install prisma
- name: Run Backend Migrations
working-directory: ./autogpt_platform/backend
run: |
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}
- name: Run Market Migrations
working-directory: ./autogpt_platform/market
run: |
python -m prisma migrate deploy
env:
DATABASE_URL: ${{ secrets.MARKET_DATABASE_URL }}
build-push-deploy:
name: Build, Push, and Deploy
needs: migrate
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
with:
fetch-depth: 0
- id: 'auth'
uses: 'google-github-actions/auth@v1'
with:
workload_identity_provider: 'projects/638488734936/locations/global/workloadIdentityPools/dev-pool/providers/github'
service_account: 'dev-github-actions-sa@agpt-dev.iam.gserviceaccount.com'
token_format: 'access_token'
create_credentials_file: true
- name: 'Set up Cloud SDK'
uses: 'google-github-actions/setup-gcloud@v1'
- name: 'Configure Docker'
run: |
gcloud auth configure-docker us-east1-docker.pkg.dev
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Cache Docker layers
uses: actions/cache@v2
with:
path: /tmp/.buildx-cache
key: ${{ runner.os }}-buildx-${{ github.sha }}
restore-keys: |
${{ runner.os }}-buildx-
- name: Check for changes
id: check_changes
run: |
git fetch origin dev
BACKEND_CHANGED=$(git diff --name-only origin/dev HEAD | grep "^autogpt_platform/backend/" && echo "true" || echo "false")
FRONTEND_CHANGED=$(git diff --name-only origin/dev HEAD | grep "^autogpt_platform/frontend/" && echo "true" || echo "false")
MARKET_CHANGED=$(git diff --name-only origin/dev HEAD | grep "^autogpt_platform/market/" && echo "true" || echo "false")
echo "backend_changed=$BACKEND_CHANGED" >> $GITHUB_OUTPUT
echo "frontend_changed=$FRONTEND_CHANGED" >> $GITHUB_OUTPUT
echo "market_changed=$MARKET_CHANGED" >> $GITHUB_OUTPUT
- name: Get GKE credentials
uses: 'google-github-actions/get-gke-credentials@v1'
with:
cluster_name: ${{ env.GKE_CLUSTER }}
location: ${{ env.GKE_ZONE }}
- name: Build and Push Backend
if: steps.check_changes.outputs.backend_changed == 'true'
uses: docker/build-push-action@v2
with:
context: .
file: ./autogpt_platform/backend/Dockerfile
push: true
tags: us-east1-docker.pkg.dev/agpt-dev/agpt-backend-dev/agpt-backend-dev:${{ github.sha }}
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Build and Push Frontend
if: steps.check_changes.outputs.frontend_changed == 'true'
uses: docker/build-push-action@v2
with:
context: .
file: ./autogpt_platform/frontend/Dockerfile
push: true
tags: us-east1-docker.pkg.dev/agpt-dev/agpt-frontend-dev/agpt-frontend-dev:${{ github.sha }}
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Build and Push Market
if: steps.check_changes.outputs.market_changed == 'true'
uses: docker/build-push-action@v2
with:
context: .
file: ./autogpt_platform/market/Dockerfile
push: true
tags: us-east1-docker.pkg.dev/agpt-dev/agpt-market-dev/agpt-market-dev:${{ github.sha }}
cache-from: type=local,src=/tmp/.buildx-cache
cache-to: type=local,dest=/tmp/.buildx-cache-new,mode=max
- name: Move cache
run: |
rm -rf /tmp/.buildx-cache
mv /tmp/.buildx-cache-new /tmp/.buildx-cache
- name: Set up Helm
uses: azure/setup-helm@v1
with:
version: v3.4.0
- name: Deploy Backend
if: steps.check_changes.outputs.backend_changed == 'true'
run: |
helm upgrade autogpt-server ./autogpt-server \
--namespace ${{ env.NAMESPACE }} \
-f autogpt-server/values.yaml \
-f autogpt-server/values.dev.yaml \
--set image.tag=${{ github.sha }}
- name: Deploy Websocket
if: steps.check_changes.outputs.backend_changed == 'true'
run: |
helm upgrade autogpt-websocket-server ./autogpt-websocket-server \
--namespace ${{ env.NAMESPACE }} \
-f autogpt-websocket-server/values.yaml \
-f autogpt-websocket-server/values.dev.yaml \
--set image.tag=${{ github.sha }}
- name: Deploy Market
if: steps.check_changes.outputs.market_changed == 'true'
run: |
helm upgrade autogpt-market ./autogpt-market \
--namespace ${{ env.NAMESPACE }} \
-f autogpt-market/values.yaml \
-f autogpt-market/values.dev.yaml \
--set image.tag=${{ github.sha }}
- name: Deploy Frontend
if: steps.check_changes.outputs.frontend_changed == 'true'
run: |
helm upgrade autogpt-builder ./autogpt-builder \
--namespace ${{ env.NAMESPACE }} \
-f autogpt-builder/values.yaml \
-f autogpt-builder/values.dev.yaml \
--set image.tag=${{ github.sha }}

View File

@@ -39,10 +39,27 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
with:
# this might remove tools that are actually needed,
# if set to "true" but frees about 6 GB
tool-cache: false
# all of these default to true, but feel free to set to
# "false" if necessary for your workflow
android: false
dotnet: false
haskell: false
large-packages: true
docker-images: true
swap-storage: true
- name: Checkout repository
uses: actions/checkout@v4
with:
submodules: recursive
- name: Set up Node.js
uses: actions/setup-node@v4
with:

View File

@@ -1,11 +1,12 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from redis import Redis
from supabase import Client
from backend.executor.database import DatabaseManager
from autogpt_libs.utils.cache import thread_cached_property
from autogpt_libs.utils.synchronize import RedisKeyedMutex
from .types import (
@@ -18,9 +19,14 @@ from .types import (
class SupabaseIntegrationCredentialsStore:
def __init__(self, supabase: "Client", redis: "Redis"):
self.supabase = supabase
def __init__(self, redis: "Redis"):
self.locks = RedisKeyedMutex(redis)
@thread_cached_property
def db_manager(self) -> "DatabaseManager":
from backend.executor.database import DatabaseManager
from backend.util.service import get_service_client
return get_service_client(DatabaseManager)
def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_metadata(user_id):
@@ -35,7 +41,9 @@ class SupabaseIntegrationCredentialsStore:
def get_all_creds(self, user_id: str) -> list[Credentials]:
user_metadata = self._get_user_metadata(user_id)
return UserMetadata.model_validate(user_metadata).integration_credentials
return UserMetadata.model_validate(
user_metadata.model_dump()
).integration_credentials
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
all_credentials = self.get_all_creds(user_id)
@@ -90,9 +98,7 @@ class SupabaseIntegrationCredentialsStore:
]
self._set_user_integration_creds(user_id, filtered_credentials)
async def store_state_token(
self, user_id: str, provider: str, scopes: list[str]
) -> str:
def store_state_token(self, user_id: str, provider: str, scopes: list[str]) -> str:
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)
@@ -105,17 +111,17 @@ class SupabaseIntegrationCredentialsStore:
with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states = user_metadata.integration_oauth_states
oauth_states.append(state.model_dump())
user_metadata["integration_oauth_states"] = oauth_states
user_metadata.integration_oauth_states = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
self.db_manager.update_user_metadata(
user_id=user_id, metadata=user_metadata
)
return token
async def get_any_valid_scopes_from_state_token(
def get_any_valid_scopes_from_state_token(
self, user_id: str, token: str, provider: str
) -> list[str]:
"""
@@ -126,7 +132,7 @@ class SupabaseIntegrationCredentialsStore:
THE CODE FOR TOKENS.
"""
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states = user_metadata.integration_oauth_states
now = datetime.now(timezone.utc)
valid_state = next(
@@ -145,10 +151,10 @@ class SupabaseIntegrationCredentialsStore:
return []
async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states = user_metadata.integration_oauth_states
now = datetime.now(timezone.utc)
valid_state = next(
@@ -165,10 +171,8 @@ class SupabaseIntegrationCredentialsStore:
if valid_state:
# Remove the used state
oauth_states.remove(valid_state)
user_metadata["integration_oauth_states"] = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)
user_metadata.integration_oauth_states = oauth_states
self.db_manager.update_user_metadata(user_id, user_metadata)
return True
return False
@@ -177,19 +181,13 @@ class SupabaseIntegrationCredentialsStore:
self, user_id: str, credentials: list[Credentials]
) -> None:
raw_metadata = self._get_user_metadata(user_id)
raw_metadata.update(
{"integration_credentials": [c.model_dump() for c in credentials]}
)
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": raw_metadata}
)
raw_metadata.integration_credentials = [c.model_dump() for c in credentials]
self.db_manager.update_user_metadata(user_id, raw_metadata)
def _get_user_metadata(self, user_id: str) -> UserMetadataRaw:
response = self.supabase.auth.admin.get_user_by_id(user_id)
if not response.user:
raise ValueError(f"User with ID {user_id} not found")
return cast(UserMetadataRaw, response.user.user_metadata)
metadata: UserMetadataRaw = self.db_manager.get_user_metadata(user_id=user_id)
return metadata
def locked_user_metadata(self, user_id: str):
key = (self.supabase.supabase_url, f"user:{user_id}", "metadata")
key = (self.db_manager, f"user:{user_id}", "metadata")
return self.locks.locked(key)

View File

@@ -56,6 +56,7 @@ class OAuthState(BaseModel):
token: str
provider: str
expires_at: int
scopes: list[str]
"""Unix timestamp (seconds) indicating when this OAuth state expires"""
@@ -64,6 +65,6 @@ class UserMetadata(BaseModel):
integration_oauth_states: list[OAuthState] = Field(default_factory=list)
class UserMetadataRaw(TypedDict, total=False):
integration_credentials: list[dict]
integration_oauth_states: list[dict]
class UserMetadataRaw(BaseModel):
integration_credentials: list[dict] = Field(default_factory=list)
integration_oauth_states: list[dict] = Field(default_factory=list)

View File

@@ -0,0 +1,27 @@
import threading
from functools import wraps
from typing import Callable, ParamSpec, TypeVar
T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R")
def thread_cached(func: Callable[P, R]) -> Callable[P, R]:
thread_local = threading.local()
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
cache = getattr(thread_local, "cache", None)
if cache is None:
cache = thread_local.cache = {}
key = (args, tuple(sorted(kwargs.items())))
if key not in cache:
cache[key] = func(*args, **kwargs)
return cache[key]
return wrapper
def thread_cached_property(func: Callable[[T], R]) -> property:
return property(thread_cached(func))

View File

@@ -20,13 +20,13 @@ PYRO_HOST=localhost
SENTRY_DSN=
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
ENABLE_AUTH=false
ENABLE_AUTH=true
SUPABASE_URL=http://localhost:8000
SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
# For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow for integrations to work.
# FRONTEND_BASE_URL=http://localhost:3000
FRONTEND_BASE_URL=http://localhost:3000
## == INTEGRATION CREDENTIALS == ##
# Each set of server side credentials is required for the corresponding 3rd party

View File

@@ -8,7 +8,7 @@ WORKDIR /app
# Install build dependencies
RUN apt-get update \
&& apt-get install -y build-essential curl ffmpeg wget libcurl4-gnutls-dev libexpat1-dev gettext libz-dev libssl-dev postgresql-client git \
&& apt-get install -y build-essential curl ffmpeg wget libcurl4-gnutls-dev libexpat1-dev libpq5 gettext libz-dev libssl-dev postgresql-client git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

View File

@@ -37,7 +37,7 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes
5. Generate the Prisma client
```sh
poetry run prisma generate --schema postgres/schema.prisma
poetry run prisma generate
```
@@ -61,7 +61,7 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes
```sh
cd ../backend
prisma migrate dev --schema postgres/schema.prisma
prisma migrate deploy
```
## Running The Server

View File

@@ -58,17 +58,18 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes
6. Migrate the database. Be careful because this deletes current data in the database.
```sh
docker compose up db redis -d
poetry run prisma migrate dev
docker compose up db -d
poetry run prisma migrate deploy
```
## Running The Server
### Starting the server without Docker
Run the following command to build the dockerfiles:
Run the following command to run database in docker but the application locally:
```sh
docker compose --profile local up deps --build --detach
poetry run app
```

View File

@@ -2,6 +2,7 @@ import importlib
import os
import re
from pathlib import Path
from typing import Type, TypeVar
from backend.data.block import Block
@@ -24,28 +25,31 @@ for module in modules:
AVAILABLE_MODULES.append(module)
# Load all Block instances from the available modules
AVAILABLE_BLOCKS = {}
AVAILABLE_BLOCKS: dict[str, Type[Block]] = {}
def all_subclasses(clz):
subclasses = clz.__subclasses__()
T = TypeVar("T")
def all_subclasses(cls: Type[T]) -> list[Type[T]]:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses += all_subclasses(subclass)
return subclasses
for cls in all_subclasses(Block):
name = cls.__name__
for block_cls in all_subclasses(Block):
name = block_cls.__name__
if cls.__name__.endswith("Base"):
if block_cls.__name__.endswith("Base"):
continue
if not cls.__name__.endswith("Block"):
if not block_cls.__name__.endswith("Block"):
raise ValueError(
f"Block class {cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
f"Block class {block_cls.__name__} does not end with 'Block', If you are creating an abstract class, please name the class with 'Base' at the end"
)
block = cls()
block = block_cls.create()
if not isinstance(block.id, str) or len(block.id) != 36:
raise ValueError(f"Block ID {block.name} error: {block.id} is not a valid UUID")
@@ -87,6 +91,6 @@ for cls in all_subclasses(Block):
if block.disabled:
continue
AVAILABLE_BLOCKS[block.id] = block
AVAILABLE_BLOCKS[block.id] = block_cls
__all__ = ["AVAILABLE_MODULES", "AVAILABLE_BLOCKS"]

View File

@@ -62,7 +62,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-20240620"
CLAUDE_3_5_SONNET = "claude-3-5-sonnet-latest"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# Groq models
LLAMA3_8B = "llama3-8b-8192"
@@ -122,6 +122,17 @@ for model in LlmModel:
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
class MessageRole(str, Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
class Message(BlockSchema):
role: MessageRole
content: str
class AIStructuredResponseGeneratorBlock(Block):
class Input(BlockSchema):
prompt: str = SchemaField(
@@ -144,6 +155,10 @@ class AIStructuredResponseGeneratorBlock(Block):
default="",
description="The system prompt to provide additional context to the model.",
)
conversation_history: list[Message] = SchemaField(
default=[],
description="The conversation history to provide context for the prompt.",
)
retry: int = SchemaField(
title="Retry Count",
default=3,
@@ -152,6 +167,11 @@ class AIStructuredResponseGeneratorBlock(Block):
prompt_values: dict[str, str] = SchemaField(
advanced=False, default={}, description="Values used to fill in the prompt."
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
class Output(BlockSchema):
response: dict[str, Any] = SchemaField(
@@ -177,26 +197,47 @@ class AIStructuredResponseGeneratorBlock(Block):
},
test_output=("response", {"key1": "key1Value", "key2": "key2Value"}),
test_mock={
"llm_call": lambda *args, **kwargs: json.dumps(
{
"key1": "key1Value",
"key2": "key2Value",
}
"llm_call": lambda *args, **kwargs: (
json.dumps(
{
"key1": "key1Value",
"key2": "key2Value",
}
),
0,
0,
)
},
)
@staticmethod
def llm_call(
api_key: str, model: LlmModel, prompt: list[dict], json_format: bool
) -> str:
provider = model.metadata.provider
api_key: str,
llm_model: LlmModel,
prompt: list[dict],
json_format: bool,
max_tokens: int | None = None,
) -> tuple[str, int, int]:
"""
Args:
api_key: API key for the LLM provider.
llm_model: The LLM model to use.
prompt: The prompt to send to the LLM.
json_format: Whether the response should be in JSON format.
max_tokens: The maximum number of tokens to generate in the chat completion.
Returns:
The response from the LLM.
The number of tokens used in the prompt.
The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
if provider == "openai":
openai.api_key = api_key
response_format = None
if model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
@@ -207,11 +248,17 @@ class AIStructuredResponseGeneratorBlock(Block):
response_format = {"type": "json_object"}
response = openai.chat.completions.create(
model=model.value,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
)
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
return response.choices[0].message.content or ""
elif provider == "anthropic":
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
sysprompt = " ".join(system_messages)
@@ -229,13 +276,18 @@ class AIStructuredResponseGeneratorBlock(Block):
client = anthropic.Anthropic(api_key=api_key)
try:
response = client.messages.create(
model=model.value,
max_tokens=4096,
resp = client.messages.create(
model=llm_model.value,
system=sysprompt,
messages=messages,
max_tokens=max_tokens or 8192,
)
return (
resp.content[0].text if resp.content else "",
resp.usage.input_tokens,
resp.usage.output_tokens,
)
return response.content[0].text if response.content else ""
except anthropic.APIError as e:
error_message = f"Anthropic API error: {str(e)}"
logger.error(error_message)
@@ -244,23 +296,35 @@ class AIStructuredResponseGeneratorBlock(Block):
client = Groq(api_key=api_key)
response_format = {"type": "json_object"} if json_format else None
response = client.chat.completions.create(
model=model.value,
model=llm_model.value,
messages=prompt, # type: ignore
response_format=response_format, # type: ignore
max_tokens=max_tokens,
)
return (
response.choices[0].message.content or "",
response.usage.prompt_tokens if response.usage else 0,
response.usage.completion_tokens if response.usage else 0,
)
return response.choices[0].message.content or ""
elif provider == "ollama":
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
response = ollama.generate(
model=model.value,
prompt=prompt[0]["content"],
model=llm_model.value,
prompt=f"{sys_messages}\n\n{usr_messages}",
stream=False,
)
return (
response.get("response") or "",
response.get("prompt_eval_count") or 0,
response.get("eval_count") or 0,
)
return response["response"]
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
def run(self, input_data: Input, **kwargs) -> BlockOutput:
logger.debug(f"Calling LLM with input data: {input_data}")
prompt = []
prompt = [p.model_dump() for p in input_data.conversation_history]
def trim_prompt(s: str) -> str:
lines = s.strip().split("\n")
@@ -289,7 +353,8 @@ class AIStructuredResponseGeneratorBlock(Block):
)
prompt.append({"role": "system", "content": sys_prompt})
prompt.append({"role": "user", "content": input_data.prompt})
if input_data.prompt:
prompt.append({"role": "user", "content": input_data.prompt})
def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
try:
@@ -305,19 +370,26 @@ class AIStructuredResponseGeneratorBlock(Block):
logger.info(f"LLM request: {prompt}")
retry_prompt = ""
model = input_data.model
llm_model = input_data.model
api_key = (
input_data.api_key.get_secret_value()
or LlmApiKeys[model.metadata.provider].get_secret_value()
or LlmApiKeys[llm_model.metadata.provider].get_secret_value()
)
for retry_count in range(input_data.retry):
try:
response_text = self.llm_call(
response_text, input_token, output_token = self.llm_call(
api_key=api_key,
model=model,
llm_model=llm_model,
prompt=prompt,
json_format=bool(input_data.expected_format),
max_tokens=input_data.max_tokens,
)
self.merge_stats(
{
"input_token_count": input_token,
"output_token_count": output_token,
}
)
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
@@ -354,8 +426,15 @@ class AIStructuredResponseGeneratorBlock(Block):
)
prompt.append({"role": "user", "content": retry_prompt})
except Exception as e:
logger.error(f"Error calling LLM: {e}")
logger.exception(f"Error calling LLM: {e}")
retry_prompt = f"Error calling LLM: {e}"
finally:
self.merge_stats(
{
"llm_call_count": retry_count + 1,
"llm_retry_count": retry_count,
}
)
raise RuntimeError(retry_prompt)
@@ -386,6 +465,11 @@ class AITextGeneratorBlock(Block):
prompt_values: dict[str, str] = SchemaField(
advanced=False, default={}, description="Values used to fill in the prompt."
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
class Output(BlockSchema):
response: str = SchemaField(
@@ -405,15 +489,11 @@ class AITextGeneratorBlock(Block):
test_mock={"llm_call": lambda *args, **kwargs: "Response text"},
)
@staticmethod
def llm_call(input_data: AIStructuredResponseGeneratorBlock.Input) -> str:
object_block = AIStructuredResponseGeneratorBlock()
for output_name, output_data in object_block.run(input_data):
if output_name == "response":
return output_data["response"]
else:
raise RuntimeError(output_data)
raise ValueError("Failed to get a response from the LLM.")
def llm_call(self, input_data: AIStructuredResponseGeneratorBlock.Input) -> str:
block = AIStructuredResponseGeneratorBlock()
response = block.run_once(input_data, "response")
self.merge_stats(block.execution_stats)
return response["response"]
def run(self, input_data: Input, **kwargs) -> BlockOutput:
object_input_data = AIStructuredResponseGeneratorBlock.Input(
@@ -517,15 +597,11 @@ class AITextSummarizerBlock(Block):
return chunks
@staticmethod
def llm_call(
input_data: AIStructuredResponseGeneratorBlock.Input,
) -> dict[str, str]:
llm_block = AIStructuredResponseGeneratorBlock()
for output_name, output_data in llm_block.run(input_data):
if output_name == "response":
return output_data
raise ValueError("Failed to get a response from the LLM.")
def llm_call(self, input_data: AIStructuredResponseGeneratorBlock.Input) -> dict:
block = AIStructuredResponseGeneratorBlock()
response = block.run_once(input_data, "response")
self.merge_stats(block.execution_stats)
return response
def _summarize_chunk(self, chunk: str, input_data: Input) -> str:
prompt = f"Summarize the following text in a {input_data.style} form. Focus your summary on the topic of `{input_data.focus}` if present, otherwise just provide a general summary:\n\n```{chunk}```"
@@ -574,17 +650,6 @@ class AITextSummarizerBlock(Block):
] # Get the first yielded value
class MessageRole(str, Enum):
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
class Message(BlockSchema):
role: MessageRole
content: str
class AIConversationBlock(Block):
class Input(BlockSchema):
messages: List[Message] = SchemaField(
@@ -599,9 +664,9 @@ class AIConversationBlock(Block):
value="", description="API key for the chosen language model provider."
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate in the chat completion.",
ge=1,
)
class Output(BlockSchema):
@@ -639,62 +704,22 @@ class AIConversationBlock(Block):
},
)
@staticmethod
def llm_call(
api_key: str,
model: LlmModel,
messages: List[dict[str, str]],
max_tokens: int | None = None,
) -> str:
provider = model.metadata.provider
if provider == "openai":
openai.api_key = api_key
response = openai.chat.completions.create(
model=model.value,
messages=messages, # type: ignore
max_tokens=max_tokens,
)
return response.choices[0].message.content or ""
elif provider == "anthropic":
client = anthropic.Anthropic(api_key=api_key)
response = client.messages.create(
model=model.value,
max_tokens=max_tokens or 4096,
messages=messages, # type: ignore
)
return response.content[0].text if response.content else ""
elif provider == "groq":
client = Groq(api_key=api_key)
response = client.chat.completions.create(
model=model.value,
messages=messages, # type: ignore
max_tokens=max_tokens,
)
return response.choices[0].message.content or ""
elif provider == "ollama":
response = ollama.chat(
model=model.value,
messages=messages, # type: ignore
stream=False, # type: ignore
)
return response["message"]["content"]
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
def llm_call(self, input_data: AIStructuredResponseGeneratorBlock.Input) -> str:
    """Delegate to AIStructuredResponseGeneratorBlock and return its reply text.

    Runs the delegate block once and folds its execution stats into this
    block's stats via ``merge_stats`` so the delegation is accounted for.
    """
    block = AIStructuredResponseGeneratorBlock()
    # run_once returns the value yielded under the "response" output name;
    # that value is a mapping whose "response" entry holds the reply text.
    response = block.run_once(input_data, "response")
    self.merge_stats(block.execution_stats)
    return response["response"]
def run(self, input_data: Input, **kwargs) -> BlockOutput:
api_key = (
input_data.api_key.get_secret_value()
or LlmApiKeys[input_data.model.metadata.provider].get_secret_value()
)
messages = [message.model_dump() for message in input_data.messages]
response = self.llm_call(
api_key=api_key,
model=input_data.model,
messages=messages,
max_tokens=input_data.max_tokens,
AIStructuredResponseGeneratorBlock.Input(
prompt="",
api_key=input_data.api_key,
model=input_data.model,
conversation_history=input_data.messages,
max_tokens=input_data.max_tokens,
expected_format={},
)
)
yield "response", response
@@ -727,6 +752,11 @@ class AIListGeneratorBlock(Block):
ge=1,
le=5,
)
max_tokens: int | None = SchemaField(
advanced=True,
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
class Output(BlockSchema):
generated_list: List[str] = SchemaField(description="The generated list.")
@@ -781,11 +811,8 @@ class AIListGeneratorBlock(Block):
input_data: AIStructuredResponseGeneratorBlock.Input,
) -> dict[str, str]:
llm_block = AIStructuredResponseGeneratorBlock()
for output_name, output_data in llm_block.run(input_data):
if output_name == "response":
logger.debug(f"Received response from LLM: {output_data}")
return output_data
raise ValueError("Failed to get a response from the LLM.")
response = llm_block.run_once(input_data, "response")
return response
@staticmethod
def string_to_list(string):

View File

@@ -230,6 +230,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.disabled = disabled
self.static_output = static_output
self.block_type = block_type
self.execution_stats = {}
@classmethod
def create(cls: Type["Block"]) -> "Block":
    """Instantiate this block class via its no-argument constructor."""
    return cls()
@abstractmethod
def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
@@ -244,6 +249,26 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
"""
pass
def run_once(self, input_data: BlockSchemaInputType, output: str, **kwargs) -> Any:
    """Run the block and return the first value yielded under ``output``.

    Consumes ``self.run(...)`` only as far as the first matching
    ``(name, data)`` pair.

    Raises:
        ValueError: If the run finishes without yielding ``output``.
    """
    _missing = object()
    result = next(
        (data for name, data in self.run(input_data, **kwargs) if name == output),
        _missing,
    )
    if result is _missing:
        raise ValueError(f"{self.name} did not produce any output for {output}")
    return result
def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
    """Fold ``stats`` into ``self.execution_stats`` in place and return it.

    Merge rules by value type: dicts are shallow-updated, ints/floats are
    summed, lists are concatenated, and any other value overwrites the
    existing entry.
    """
    merged = self.execution_stats
    for key, incoming in stats.items():
        if isinstance(incoming, dict):
            merged.setdefault(key, {}).update(incoming)
        elif isinstance(incoming, (int, float)):
            merged[key] = merged.setdefault(key, 0) + incoming
        elif isinstance(incoming, list):
            merged.setdefault(key, []).extend(incoming)
        else:
            merged[key] = incoming
    return merged
@property
def name(self):
    """Human-readable block name: the concrete class's name."""
    return type(self).__name__
@@ -282,14 +307,15 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
# ======================= Block Helper Functions ======================= #
def get_blocks() -> dict[str, Block]:
def get_blocks() -> dict[str, Type[Block]]:
from backend.blocks import AVAILABLE_BLOCKS # noqa: E402
return AVAILABLE_BLOCKS
async def initialize_blocks() -> None:
for block in get_blocks().values():
for cls in get_blocks().values():
block = cls()
existing_block = await AgentBlock.prisma().find_first(
where={"OR": [{"id": block.id}, {"name": block.name}]}
)
@@ -324,4 +350,5 @@ async def initialize_blocks() -> None:
def get_block(block_id: str) -> Block | None:
return get_blocks().get(block_id)
cls = get_blocks().get(block_id)
return cls() if cls else None

View File

@@ -257,7 +257,7 @@ class Graph(GraphMeta):
block = get_block(node.block_id)
if not block:
blocks = {v.id: v.name for v in get_blocks().values()}
blocks = {v().id: v().name for v in get_blocks().values()}
raise ValueError(
f"{suffix}, {node.block_id} is invalid block id, available blocks: {blocks}"
)

View File

@@ -1,6 +1,8 @@
from typing import Optional
from autogpt_libs.supabase_integration_credentials_store.types import UserMetadataRaw
from fastapi import HTTPException
from prisma import Json
from prisma.models import User
from backend.data.db import prisma
@@ -35,16 +37,32 @@ async def get_user_by_id(user_id: str) -> Optional[User]:
return User.model_validate(user) if user else None
async def create_default_user(enable_auth: str) -> Optional[User]:
if not enable_auth.lower() == "true":
user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
if not user:
user = await prisma.user.create(
data={
"id": DEFAULT_USER_ID,
"email": "default@example.com",
"name": "Default User",
}
)
return User.model_validate(user)
return None
async def create_default_user() -> Optional[User]:
    """Return the default user, creating a placeholder row on first call.

    Looks up DEFAULT_USER_ID; if absent, inserts a user with the fixed email
    "default@example.com" and name "Default User".

    Returns:
        The validated User model.
        NOTE(review): the annotation is Optional but this path now always
        returns a User — confirm whether callers still expect None.
    """
    user = await prisma.user.find_unique(where={"id": DEFAULT_USER_ID})
    if not user:
        user = await prisma.user.create(
            data={
                "id": DEFAULT_USER_ID,
                "email": "default@example.com",
                "name": "Default User",
            }
        )
    return User.model_validate(user)
async def get_user_metadata(user_id: str) -> UserMetadataRaw:
    """Return the stored metadata for ``user_id``.

    Falls back to an empty ``UserMetadataRaw()`` when the user's metadata
    column is null/unset. Raises (via ``find_unique_or_raise``) when no user
    row exists for ``user_id``.
    """
    user = await User.prisma().find_unique_or_raise(
        where={"id": user_id},
    )
    return (
        UserMetadataRaw.model_validate(user.metadata)
        if user.metadata
        else UserMetadataRaw()
    )
async def update_user_metadata(user_id: str, metadata: UserMetadataRaw):
    """Overwrite the user's metadata column with the serialized model.

    The full model is dumped and stored as JSON — this replaces, not merges,
    any previously stored metadata.
    """
    await User.prisma().update(
        where={"id": user_id},
        data={"metadata": Json(metadata.model_dump())},
    )

View File

@@ -16,6 +16,7 @@ from backend.data.execution import (
)
from backend.data.graph import get_graph, get_node
from backend.data.queue import RedisEventQueue
from backend.data.user import get_user_metadata, update_user_metadata
from backend.util.service import AppService, expose
from backend.util.settings import Config
@@ -26,11 +27,15 @@ R = TypeVar("R")
class DatabaseManager(AppService):
def __init__(self):
super().__init__(port=Config().database_api_port)
super().__init__()
self.use_db = True
self.use_redis = True
self.event_queue = RedisEventQueue()
@classmethod
def get_port(cls) -> int:
return Config().database_api_port
@expose
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
self.event_queue.put(ExecutionResult(**execution_result_dict))
@@ -73,3 +78,7 @@ class DatabaseManager(AppService):
Callable[[Any, str, int, str, dict[str, str], float, float], int],
exposed_run_and_wait(user_credit_model.spend_credits),
)
# User + User Metadata
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)

View File

@@ -16,6 +16,8 @@ from redis.lock import Lock as RedisLock
if TYPE_CHECKING:
from backend.executor import DatabaseManager
from autogpt_libs.utils.cache import thread_cached
from backend.data import redis
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.execution import (
@@ -31,7 +33,6 @@ from backend.data.graph import Graph, Link, Node
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.cache import thread_cached_property
from backend.util.decorator import error_logged, time_measured
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
@@ -104,6 +105,7 @@ def execute_node(
Args:
db_client: The client to send execution updates to the server.
creds_manager: The manager to acquire and release credentials.
data: The execution data for executing the current node.
execution_stats: The execution statistics to be updated.
@@ -209,6 +211,7 @@ def execute_node(
if creds_lock:
creds_lock.release()
if execution_stats is not None:
execution_stats.update(node_block.execution_stats)
execution_stats["input_size"] = input_size
execution_stats["output_size"] = output_size
@@ -657,20 +660,24 @@ class Executor:
class ExecutionManager(AppService):
def __init__(self):
super().__init__(port=settings.config.execution_manager_port)
super().__init__()
self.use_redis = True
self.use_supabase = True
self.pool_size = settings.config.num_graph_workers
self.queue = ExecutionQueue[GraphExecution]()
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
@classmethod
def get_port(cls) -> int:
return settings.config.execution_manager_port
def run_service(self):
from autogpt_libs.supabase_integration_credentials_store import (
SupabaseIntegrationCredentialsStore,
)
self.credentials_store = SupabaseIntegrationCredentialsStore(
self.supabase, redis.get_redis()
redis=redis.get_redis()
)
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
@@ -701,7 +708,7 @@ class ExecutionManager(AppService):
super().cleanup()
@thread_cached_property
@property
def db_client(self) -> "DatabaseManager":
return get_db_client()
@@ -857,10 +864,11 @@ class ExecutionManager(AppService):
# ------- UTILITIES ------- #
@thread_cached
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager, settings.config.database_api_port)
return get_service_client(DatabaseManager)
@contextmanager

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from autogpt_libs.utils.cache import thread_cached_property
from backend.data.block import BlockInput
from backend.data.schedule import (
@@ -14,7 +15,6 @@ from backend.data.schedule import (
update_schedule,
)
from backend.executor.manager import ExecutionManager
from backend.util.cache import thread_cached_property
from backend.util.service import AppService, expose, get_service_client
from backend.util.settings import Config
@@ -28,14 +28,18 @@ def log(msg, **kwargs):
class ExecutionScheduler(AppService):
def __init__(self, refresh_interval=10):
super().__init__(port=Config().execution_scheduler_port)
super().__init__()
self.use_db = True
self.last_check = datetime.min
self.refresh_interval = refresh_interval
@classmethod
def get_port(cls) -> int:
return Config().execution_scheduler_port
@thread_cached_property
def execution_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager, Config().execution_manager_port)
return get_service_client(ExecutionManager)
def run_service(self):
scheduler = BackgroundScheduler()

View File

@@ -13,8 +13,6 @@ from backend.data import redis
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings
from ..server.integrations.utils import get_supabase
logger = logging.getLogger(__name__)
settings = Settings()
@@ -54,7 +52,7 @@ class IntegrationCredentialsManager:
def __init__(self):
redis_conn = redis.get_redis()
self._locks = RedisKeyedMutex(redis_conn)
self.store = SupabaseIntegrationCredentialsStore(get_supabase(), redis_conn)
self.store = SupabaseIntegrationCredentialsStore(redis=redis_conn)
def create(self, user_id: str, credentials: Credentials) -> None:
return self.store.add_creds(user_id, credentials)
@@ -131,7 +129,7 @@ class IntegrationCredentialsManager:
def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
key = (
self.store.supabase.supabase_url,
self.store.db_manager,
f"user:{user_id}",
f"credentials:{credentials_id}",
*args,

View File

@@ -19,6 +19,7 @@ from ..utils import get_user_id
logger = logging.getLogger(__name__)
settings = Settings()
router = APIRouter()
creds_manager = IntegrationCredentialsManager()
@@ -41,7 +42,7 @@ async def login(
requested_scopes = scopes.split(",") if scopes else []
# Generate and store a secure random state token along with the scopes
state_token = await creds_manager.store.store_state_token(
state_token = creds_manager.store.store_state_token(
user_id, provider, requested_scopes
)
@@ -70,12 +71,12 @@ async def callback(
handler = _get_provider_oauth_handler(request, provider)
# Verify the state token
if not await creds_manager.store.verify_state_token(user_id, state_token, provider):
if not creds_manager.store.verify_state_token(user_id, state_token, provider):
logger.warning(f"Invalid or expired state token for user {user_id}")
raise HTTPException(status_code=400, detail="Invalid or expired state token")
try:
scopes = await creds_manager.store.get_any_valid_scopes_from_state_token(
scopes = creds_manager.store.get_any_valid_scopes_from_state_token(
user_id, state_token, provider
)
logger.debug(f"Retrieved scopes from state token: {scopes}")

View File

@@ -7,6 +7,7 @@ from typing import Annotated, Any, Dict
import uvicorn
from autogpt_libs.auth.middleware import auth_middleware
from autogpt_libs.utils.cache import thread_cached_property
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
@@ -19,9 +20,7 @@ from backend.data.block import BlockInput, CompletedBlockOutput
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.cache import thread_cached_property
from backend.util.service import AppService, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings
@@ -36,9 +35,13 @@ class AgentServer(AppService):
_user_credit_model = get_user_credit_model()
def __init__(self):
super().__init__(port=Config().agent_server_port)
super().__init__()
self.use_redis = True
@classmethod
def get_port(cls) -> int:
return Config().agent_server_port
@asynccontextmanager
async def lifespan(self, _: FastAPI):
await db.connect()
@@ -97,7 +100,6 @@ class AgentServer(AppService):
tags=["integrations"],
dependencies=[Depends(auth_middleware)],
)
self.integration_creds_manager = IntegrationCredentialsManager()
api_router.include_router(
backend.server.routers.analytics.router,
@@ -307,11 +309,11 @@ class AgentServer(AppService):
@thread_cached_property
def execution_manager_client(self) -> ExecutionManager:
return get_service_client(ExecutionManager, Config().execution_manager_port)
return get_service_client(ExecutionManager)
@thread_cached_property
def execution_scheduler_client(self) -> ExecutionScheduler:
return get_service_client(ExecutionScheduler, Config().execution_scheduler_port)
return get_service_client(ExecutionScheduler)
@classmethod
def handle_internal_http_error(cls, request: Request, exc: Exception):
@@ -330,9 +332,9 @@ class AgentServer(AppService):
@classmethod
def get_graph_blocks(cls) -> list[dict[Any, Any]]:
blocks = block.get_blocks()
blocks = [cls() for cls in block.get_blocks().values()]
costs = get_block_costs()
return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks.values()]
return [{**b.to_dict(), "costs": costs.get(b.id, [])} for b in blocks]
@classmethod
def execute_graph_block(

View File

@@ -28,7 +28,7 @@ async def lifespan(app: FastAPI):
docs_url = "/docs" if settings.config.app_env == AppEnvironment.LOCAL else None
app = FastAPI(lifespan=lifespan)
app = FastAPI(lifespan=lifespan, docs_url=docs_url)
_connection_manager = None
logger.info(f"CORS allow origins: {settings.config.backend_cors_allow_origins}")
@@ -66,7 +66,7 @@ async def event_broadcaster(manager: ConnectionManager):
async def authenticate_websocket(websocket: WebSocket) -> str:
if settings.config.enable_auth.lower() == "true":
if settings.config.enable_auth:
token = websocket.query_params.get("token")
if not token:
await websocket.close(code=4001, reason="Missing authentication token")

View File

@@ -1,21 +0,0 @@
import threading
from functools import wraps
from typing import Callable, TypeVar
T = TypeVar("T")
R = TypeVar("R")


def thread_cached_property(func: Callable[[T], R]) -> property:
    """Property decorator that caches the computed value per thread.

    Each thread gets its own cache dict (a ``threading.local``), keyed by
    ``id(self)``, so the wrapped getter runs at most once per instance per
    thread.
    """
    thread_store = threading.local()

    @wraps(func)
    def getter(self: T) -> R:
        cache = getattr(thread_store, "cache", None)
        if cache is None:
            cache = thread_store.cache = {}
        obj_key = id(self)
        if obj_key not in cache:
            cache[obj_key] = func(self)
        return cache[obj_key]

    return property(getter)

View File

@@ -5,6 +5,7 @@ import os
import threading
import time
import typing
from abc import ABC, abstractmethod
from enum import Enum
from types import NoneType, UnionType
from typing import (
@@ -99,16 +100,24 @@ def _make_custom_deserializer(model: Type[BaseModel]):
return custom_dict_to_class
class AppService(AppProcess):
class AppService(AppProcess, ABC):
shared_event_loop: asyncio.AbstractEventLoop
use_db: bool = False
use_redis: bool = False
use_supabase: bool = False
def __init__(self, port):
self.port = port
def __init__(self):
self.uri = None
@classmethod
@abstractmethod
def get_port(cls) -> int:
pass
@classmethod
def get_host(cls) -> str:
return os.environ.get(f"{cls.service_name.upper()}_HOST", Config().pyro_host)
def run_service(self) -> None:
while True:
time.sleep(10)
@@ -157,8 +166,7 @@ class AppService(AppProcess):
@conn_retry("Pyro", "Starting Pyro Service")
def __start_pyro(self):
host = Config().pyro_host
daemon = Pyro5.api.Daemon(host=host, port=self.port)
daemon = Pyro5.api.Daemon(host=self.get_host(), port=self.get_port())
self.uri = daemon.register(self, objectId=self.service_name)
logger.info(f"[{self.service_name}] Connected to Pyro; URI = {self.uri}")
daemon.requestLoop()
@@ -167,16 +175,20 @@ class AppService(AppProcess):
self.shared_event_loop.run_forever()
# --------- UTILITIES --------- #
AS = TypeVar("AS", bound=AppService)
def get_service_client(service_type: Type[AS], port: int) -> AS:
def get_service_client(service_type: Type[AS]) -> AS:
service_name = service_type.service_name
class DynamicClient:
@conn_retry("Pyro", f"Connecting to [{service_name}]")
def __init__(self):
host = os.environ.get(f"{service_name.upper()}_HOST", "localhost")
host = service_type.get_host()
port = service_type.get_port()
uri = f"PYRO:{service_type.service_name}@{host}:{port}"
logger.debug(f"Connecting to service [{service_name}]. URI = {uri}")
self.proxy = Pyro5.api.Proxy(uri)
@@ -191,8 +203,6 @@ def get_service_client(service_type: Type[AS], port: int) -> AS:
return cast(AS, DynamicClient())
# --------- UTILITIES --------- #
builtin_types = [*vars(builtins).values(), NoneType, Enum]

View File

@@ -69,8 +69,8 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default="localhost",
description="The default hostname of the Pyro server.",
)
enable_auth: str = Field(
default="false",
enable_auth: bool = Field(
default=True,
description="If authentication is enabled or not",
)
enable_credit: str = Field(
@@ -133,7 +133,7 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
)
frontend_base_url: str = Field(
default="",
default="http://localhost:3000",
description="Can be used to explicitly set the base URL for the frontend. "
"This value is then used to generate redirect URLs for OAuth flows.",
)

View File

@@ -31,7 +31,7 @@ class SpinTestServer:
await db.connect()
await initialize_blocks()
await create_default_user("false")
await create_default_user()
return self

View File

@@ -0,0 +1,2 @@
-- AlterTable
-- Adds a nullable JSONB column for per-user metadata; read/written by the
-- backend's get_user_metadata / update_user_metadata helpers.
ALTER TABLE "User" ADD COLUMN "metadata" JSONB;

View File

@@ -0,0 +1,27 @@
--CreateFunction
-- Mirrors newly created Supabase auth users into the platform "User" table.
-- SECURITY DEFINER so the trigger can insert across schemas.
CREATE OR REPLACE FUNCTION add_user_to_platform() RETURNS TRIGGER AS $$
BEGIN
    INSERT INTO platform."User" (id, email, "updatedAt")
    VALUES (NEW.id, NEW.email, now());
    RETURN NEW;
END;
$$ LANGUAGE plpgsql SECURITY DEFINER;

-- Install the trigger only when the Supabase auth schema exists, so this
-- migration also succeeds on databases without Supabase.
DO $$
BEGIN
    -- Check if the auth schema and users table exist
    IF EXISTS (
        SELECT 1
        FROM information_schema.tables
        WHERE table_schema = 'auth'
        AND table_name = 'users'
    ) THEN
        -- Drop the trigger if it exists (makes the migration re-runnable)
        DROP TRIGGER IF EXISTS user_added_to_platform ON auth.users;

        -- Create the trigger
        CREATE TRIGGER user_added_to_platform
        AFTER INSERT ON auth.users
        FOR EACH ROW EXECUTE FUNCTION add_user_to_platform();
    END IF;
END $$;

View File

@@ -17,6 +17,7 @@ model User {
name String?
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
metadata Json?
// Relations
AgentGraphs AgentGraph[]

View File

@@ -0,0 +1,3 @@
import os
os.environ["ENABLE_AUTH"] = "false"

View File

@@ -1,3 +1,5 @@
from typing import Type
import pytest
from backend.data.block import Block, get_blocks
@@ -5,5 +7,5 @@ from backend.util.test import execute_block_test
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
def test_available_blocks(block: Block):
execute_block_test(type(block)())
def test_available_blocks(block: Type[Block]):
execute_block_test(block())

View File

@@ -5,7 +5,6 @@ from backend.executor import ExecutionScheduler
from backend.server.model import CreateGraph
from backend.usecases.sample import create_test_graph, create_test_user
from backend.util.service import get_service_client
from backend.util.settings import Config
from backend.util.test import SpinTestServer
@@ -19,10 +18,7 @@ async def test_agent_schedule(server: SpinTestServer):
user_id=test_user.id,
)
scheduler = get_service_client(
ExecutionScheduler, Config().execution_scheduler_port
)
scheduler = get_service_client(ExecutionScheduler)
schedules = scheduler.get_execution_schedules(test_graph.id, test_user.id)
assert len(schedules) == 0

View File

@@ -7,7 +7,11 @@ TEST_SERVICE_PORT = 8765
class ServiceTest(AppService):
def __init__(self):
super().__init__(port=TEST_SERVICE_PORT)
super().__init__()
@classmethod
def get_port(cls) -> int:
return TEST_SERVICE_PORT
@expose
def add(self, a: int, b: int) -> int:
@@ -28,7 +32,7 @@ class ServiceTest(AppService):
@pytest.mark.asyncio(scope="session")
async def test_service_creation(server):
with ServiceTest():
client = get_service_client(ServiceTest, TEST_SERVICE_PORT)
client = get_service_client(ServiceTest)
assert client.add(5, 3) == 8
assert client.subtract(10, 4) == 6
assert client.fun_with_async(5, 3) == 8

View File

@@ -8,7 +8,7 @@ services:
develop:
watch:
- path: ./
target: autogpt_platform/backend/migrate
target: autogpt_platform/backend/migrations
action: rebuild
depends_on:
db:

View File

@@ -96,6 +96,36 @@ services:
file: ./supabase/docker/docker-compose.yml
service: rest
realtime:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
service: realtime
storage:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
service: storage
imgproxy:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
service: imgproxy
meta:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
service: meta
functions:
<<: *supabase-services
extends:
file: ./supabase/docker/docker-compose.yml
service: functions
analytics:
<<: *supabase-services
extends:
@@ -112,3 +142,24 @@ services:
extends:
file: ./supabase/docker/docker-compose.yml
service: vector
deps:
<<: *supabase-services
profiles:
- local
image: busybox
command: /bin/true
depends_on:
- studio
- kong
- auth
- rest
- realtime
- storage
- imgproxy
- meta
- functions
- analytics
- db
- vector
- redis

View File

@@ -144,7 +144,7 @@ const SubmitPage: React.FC = () => {
setSubmitError(null);
if (!data.agreeToTerms) {
throw new Error("You must agree to the terms of service");
throw new Error("You must agree to the terms of use");
}
try {
@@ -404,7 +404,7 @@ const SubmitPage: React.FC = () => {
<Controller
name="agreeToTerms"
control={control}
rules={{ required: "You must agree to the terms of service" }}
rules={{ required: "You must agree to the terms of use" }}
render={({ field }) => (
<div className="flex items-center space-x-2">
<Checkbox
@@ -417,8 +417,11 @@ const SubmitPage: React.FC = () => {
className="text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
>
I agree to the{" "}
<a href="/terms" className="text-blue-500 hover:underline">
terms of service
<a
href="https://auto-gpt.notion.site/Terms-of-Use-11400ef5bece80d0b087d7831c5fd6bf"
className="text-blue-500 hover:underline"
>
terms of use
</a>
</label>
</div>

View File

@@ -1,9 +1,9 @@
# dev values, overwrite base values as needed.
image:
repository: us-east1-docker.pkg.dev/agpt-dev/agpt-builder-dev/agpt-builder-dev
repository: us-east1-docker.pkg.dev/agpt-dev/agpt-frontend-dev/agpt-frontend-dev
pullPolicy: Always
tag: "fe3d2a9"
tag: "latest"
serviceAccount:
annotations:

View File

@@ -1,7 +1,7 @@
# dev values, overwrite base values as needed.
image:
repository: us-east1-docker.pkg.dev/agpt-dev/agpt-server-dev/agpt-server-dev
repository: us-east1-docker.pkg.dev/agpt-dev/agpt-backend-dev/agpt-backend-dev
pullPolicy: Always
tag: "latest"

View File

@@ -1,7 +1,7 @@
replicaCount: 1 # not scaling websocket server for now
image:
repository: us-east1-docker.pkg.dev/agpt-dev/agpt-server-dev/agpt-server-dev
repository: us-east1-docker.pkg.dev/agpt-dev/agpt-backend-dev/agpt-backend-dev
tag: latest
pullPolicy: Always

View File

@@ -28,6 +28,10 @@ service_accounts = {
"dev-agpt-market-sa" = {
display_name = "AutoGPT Dev Market Server Account"
description = "Service account for agpt dev market server"
},
"dev-github-actions-sa" = {
display_name = "GitHub Actions Dev Service Account"
description = "Service account for GitHub Actions deployments to dev"
}
}
@@ -51,6 +55,11 @@ workload_identity_bindings = {
service_account_name = "dev-agpt-market-sa"
namespace = "dev-agpt"
ksa_name = "dev-agpt-market-sa"
},
"dev-github-actions-workload-identity" = {
service_account_name = "dev-github-actions-sa"
namespace = "dev-agpt"
ksa_name = "dev-github-actions-sa"
}
}
@@ -59,7 +68,8 @@ role_bindings = {
"serviceAccount:dev-agpt-server-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-agpt-builder-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-agpt-ws-server-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-agpt-market-sa@agpt-dev.iam.gserviceaccount.com"
"serviceAccount:dev-agpt-market-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-github-actions-sa@agpt-dev.iam.gserviceaccount.com"
],
"roles/cloudsql.client" = [
"serviceAccount:dev-agpt-server-sa@agpt-dev.iam.gserviceaccount.com",
@@ -80,7 +90,8 @@ role_bindings = {
"serviceAccount:dev-agpt-server-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-agpt-builder-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-agpt-ws-server-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-agpt-market-sa@agpt-dev.iam.gserviceaccount.com"
"serviceAccount:dev-agpt-market-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-github-actions-sa@agpt-dev.iam.gserviceaccount.com"
]
"roles/compute.networkUser" = [
"serviceAccount:dev-agpt-server-sa@agpt-dev.iam.gserviceaccount.com",
@@ -93,6 +104,16 @@ role_bindings = {
"serviceAccount:dev-agpt-builder-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-agpt-ws-server-sa@agpt-dev.iam.gserviceaccount.com",
"serviceAccount:dev-agpt-market-sa@agpt-dev.iam.gserviceaccount.com"
],
"roles/artifactregistry.writer" = [
"serviceAccount:dev-github-actions-sa@agpt-dev.iam.gserviceaccount.com"
],
"roles/container.viewer" = [
"serviceAccount:dev-github-actions-sa@agpt-dev.iam.gserviceaccount.com"
],
"roles/iam.serviceAccountTokenCreator" = [
"principalSet://iam.googleapis.com/projects/638488734936/locations/global/workloadIdentityPools/dev-pool/*",
"serviceAccount:dev-github-actions-sa@agpt-dev.iam.gserviceaccount.com"
]
}
@@ -101,4 +122,25 @@ services_ip_cidr_range = "10.2.0.0/20"
public_bucket_names = ["website-artifacts"]
standard_bucket_names = []
bucket_admins = ["gcp-devops-agpt@agpt.co", "gcp-developers@agpt.co"]
bucket_admins = ["gcp-devops-agpt@agpt.co", "gcp-developers@agpt.co"]
# GitHub Actions OIDC federation for dev: workflows in the listed repositories
# can impersonate dev-github-actions-sa through this pool's "github" provider.
workload_identity_pools = {
  "dev-pool" = {
    display_name = "Development Identity Pool"
    providers = {
      "github" = {
        issuer_uri = "https://token.actions.githubusercontent.com"
        # Maps OIDC token claims onto Google attributes usable in IAM bindings.
        attribute_mapping = {
          "google.subject"             = "assertion.sub"
          "attribute.repository"       = "assertion.repository"
          "attribute.repository_owner" = "assertion.repository_owner"
        }
      }
    }
    service_accounts = {
      # Service account -> repositories allowed to assume it.
      "dev-github-actions-sa" = [
        "Significant-Gravitas/AutoGPT"
      ]
    }
  }
}

View File

@@ -28,6 +28,11 @@ service_accounts = {
"prod-agpt-market-sa" = {
display_name = "AutoGPT prod Market backend Account"
description = "Service account for agpt prod market backend"
},
"prod-github-actions-workload-identity" = {
service_account_name = "prod-github-actions-sa"
namespace = "prod-agpt"
ksa_name = "prod-github-actions-sa"
}
}
@@ -59,7 +64,8 @@ role_bindings = {
"serviceAccount:prod-agpt-backend-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-agpt-frontend-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-agpt-ws-backend-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-agpt-market-sa@agpt-prod.iam.gserviceaccount.com"
"serviceAccount:prod-agpt-market-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-github-actions-sa@agpt-prod.iam.gserviceaccount.com"
],
"roles/cloudsql.client" = [
"serviceAccount:prod-agpt-backend-sa@agpt-prod.iam.gserviceaccount.com",
@@ -80,7 +86,8 @@ role_bindings = {
"serviceAccount:prod-agpt-backend-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-agpt-frontend-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-agpt-ws-backend-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-agpt-market-sa@agpt-prod.iam.gserviceaccount.com"
"serviceAccount:prod-agpt-market-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-github-actions-sa@agpt-prod.iam.gserviceaccount.com"
]
"roles/compute.networkUser" = [
"serviceAccount:prod-agpt-backend-sa@agpt-prod.iam.gserviceaccount.com",
@@ -93,6 +100,16 @@ role_bindings = {
"serviceAccount:prod-agpt-frontend-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-agpt-ws-backend-sa@agpt-prod.iam.gserviceaccount.com",
"serviceAccount:prod-agpt-market-sa@agpt-prod.iam.gserviceaccount.com"
],
"roles/artifactregistry.writer" = [
"serviceAccount:prod-github-actions-sa@agpt-prod.iam.gserviceaccount.com"
],
"roles/container.viewer" = [
"serviceAccount:prod-github-actions-sa@agpt-prod.iam.gserviceaccount.com"
],
"roles/iam.serviceAccountTokenCreator" = [
"principalSet://iam.googleapis.com/projects/638488734936/locations/global/workloadIdentityPools/prod-pool/*",
"serviceAccount:prod-github-actions-sa@agpt-prod.iam.gserviceaccount.com"
]
}
@@ -101,4 +118,25 @@ services_ip_cidr_range = "10.2.0.0/20"
public_bucket_names = ["website-artifacts"]
standard_bucket_names = []
bucket_admins = ["gcp-devops-agpt@agpt.co", "gcp-developers@agpt.co"]
bucket_admins = ["gcp-devops-agpt@agpt.co", "gcp-developers@agpt.co"]
# GitHub Actions OIDC federation for prod. The pool key must be "prod-pool":
# the prod deploy workflow authenticates against
# .../workloadIdentityPools/prod-pool/providers/github, and the
# roles/iam.serviceAccountTokenCreator binding above also references
# prod-pool — the previous "dev-pool" key would create a pool the workflow
# never uses.
workload_identity_pools = {
  "prod-pool" = {
    display_name = "Production Identity Pool"
    providers = {
      "github" = {
        issuer_uri = "https://token.actions.githubusercontent.com"
        # Maps OIDC token claims onto Google attributes usable in IAM bindings.
        attribute_mapping = {
          "google.subject"             = "assertion.sub"
          "attribute.repository"       = "assertion.repository"
          "attribute.repository_owner" = "assertion.repository_owner"
        }
      }
    }
    service_accounts = {
      # Service account -> repositories allowed to assume it.
      "prod-github-actions-sa" = [
        "Significant-Gravitas/AutoGPT"
      ]
    }
  }
}

View File

@@ -61,6 +61,7 @@ module "iam" {
service_accounts = var.service_accounts
workload_identity_bindings = var.workload_identity_bindings
role_bindings = var.role_bindings
workload_identity_pools = var.workload_identity_pools
}
module "storage" {

View File

@@ -23,4 +23,31 @@ resource "google_project_iam_binding" "role_bindings" {
role = each.key
members = each.value
}
# One workload identity pool per entry in var.workload_identity_pools;
# the map key becomes the pool ID.
resource "google_iam_workload_identity_pool" "pools" {
  for_each                  = var.workload_identity_pools
  workload_identity_pool_id = each.key
  display_name              = each.value.display_name
}
# Flattens the nested pools -> providers map into "<pool>/<provider>" keys so
# for_each creates one provider resource per (pool, provider) pair.
resource "google_iam_workload_identity_pool_provider" "providers" {
  for_each = merge([
    for pool_id, pool in var.workload_identity_pools : {
      for provider_id, provider in pool.providers :
      "${pool_id}/${provider_id}" => merge(provider, {
        pool_id = pool_id
      })
    }
  ]...)

  workload_identity_pool_id          = split("/", each.key)[0]
  workload_identity_pool_provider_id = split("/", each.key)[1]
  attribute_mapping                  = each.value.attribute_mapping

  oidc {
    issuer_uri = each.value.issuer_uri
    # allowed_audiences is optional(list(string)); null defers to the default.
    allowed_audiences = each.value.allowed_audiences
  }

  # Hard-coded guard: only tokens for repos owned by Significant-Gravitas can
  # federate, regardless of per-pool configuration.
  attribute_condition = "assertion.repository_owner==\"Significant-Gravitas\""
}

View File

@@ -1,4 +1,14 @@
output "service_account_emails" {
description = "The emails of the created service accounts"
value = { for k, v in google_service_account.service_accounts : k => v.email }
}
}
# Descriptions added for consistency with the service_account_emails output.
output "workload_identity_pools" {
  description = "The created workload identity pool resources, keyed by pool ID"
  value       = google_iam_workload_identity_pool.pools
}

output "workload_identity_providers" {
  description = "Fully-qualified provider resource names, keyed by '<pool>/<provider>'"
  value = {
    for k, v in google_iam_workload_identity_pool_provider.providers : k => v.name
  }
}

View File

@@ -26,4 +26,17 @@ variable "role_bindings" {
description = "Map of roles to list of members"
type = map(list(string))
default = {}
}
variable "workload_identity_pools" {
  # Description added for consistency with the root module's declaration of
  # the same variable.
  description = "Configuration for workload identity pools and their providers"
  type = map(object({
    display_name = string
    providers = map(object({
      issuer_uri        = string
      attribute_mapping = map(string)
      allowed_audiences = optional(list(string))
    }))
    service_accounts = map(list(string)) # Map of SA to list of allowed principals
  }))
  default = {}
}

View File

@@ -130,3 +130,19 @@ variable "bucket_admins" {
default = ["gcp-devops-agpt@agpt.co", "gcp-developers@agpt.co"]
}
# Root-module mirror of the iam module's variable; forwarded to module "iam".
variable "workload_identity_pools" {
  type = map(object({
    display_name = string
    providers = map(object({
      issuer_uri        = string
      attribute_mapping = map(string)
      allowed_audiences = optional(list(string))
    }))
    # Map of service account name -> repositories allowed to assume it.
    service_accounts = map(list(string))
  }))
  default     = {}
  description = "Configuration for workload identity pools and their providers"
}