mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-08 22:58:01 -05:00
feat(platform): OAuth support + API key management + GitHub blocks (#8044)
## Config
- For Supabase, the back end needs `SUPABASE_URL`, `SUPABASE_SERVICE_ROLE_KEY`, and `SUPABASE_JWT_SECRET`
- For the GitHub integration to work, the back end needs `GITHUB_CLIENT_ID` and `GITHUB_CLIENT_SECRET`
- For integrations OAuth flows to work in local development, the back end needs `FRONTEND_BASE_URL` to generate login URLs with accurate redirect URLs
## REST API
- Tweak output of OAuth `/login` endpoint: add `state_token` separately in response
- Add `POST /integrations/{provider}/credentials` (for API keys)
- Add `DELETE /integrations/{provider}/credentials/{cred_id}`
## Back end
- Add Supabase support to `AppService`
- Add `FRONTEND_BASE_URL` config option, mainly for local development use
### `autogpt_libs.supabase_integration_credentials_store`
- Add `CredentialsType` alias
- Add `.bearer()` helper methods to `APIKeyCredentials` and `OAuth2Credentials`
### Blocks
- Add `CredentialsField(..) -> CredentialsMetaInput`
## Front end
### UI components
- `CredentialsInput` for use on `CustomNode`: allows user to add/select credentials for a service.
- `APIKeyCredentialsModal`: a dialog for creating API keys
- `OAuth2FlowWaitingModal`: a dialog to indicate that the application is waiting for the user to log in to the 3rd party service in the provided pop-up window
- `NodeCredentialsInput`: wrapper for `CredentialsInput` with the "usual" interface of node input components
- New icons: `IconKey`, `IconKeyPlus`, `IconUser`, `IconUserPlus`
### Data model
- `CredentialsProvider`: introduces the app-level `CredentialsProvidersContext`, which acts as an application-wide store and cache for credentials metadata.
- `useCredentials` for use on `CustomNode`: uses `CredentialsProvidersContext` and provides node-specific credential data and provider-specific data/functions
- `/auth/integrations/oauth_callback` route to close the loop to the `CredentialsInput` after a user completes sign-in to the external service
- Add `BlockIOCredentialsSubSchema`
### API client
- Add `isAuthenticated` method
- Add methods for integration OAuth flow: `oAuthLogin`, `oAuthCallback`
- Add CRD methods for credentials: `createAPIKeyCredentials`, `listCredentials`, `getCredentials`, `deleteCredentials`
- Add mirrored types `CredentialsMetaResponse`, `CredentialsMetaInput`, `OAuth2Credentials`, `APIKeyCredentials`
- Add GitHub blocks + "DEVELOPER_TOOLS" category
- Add `**kwargs` to `Block.run(..)` signature to support additional kwargs
- Add support for loading blocks from nested modules (e.g. `blocks/github/issues.py`)
#### Executor
- Add strict support for `credentials` fields on blocks
- Fetch credentials for graph execution and pass them down through to the node execution
This commit is contained in:
committed by
GitHub
parent
3a1574e4bd
commit
5e2874c315
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -105,7 +105,7 @@ jobs:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
|
||||
SUPABASE_SERVICE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
|
||||
SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
|
||||
env:
|
||||
CI: true
|
||||
|
||||
@@ -7,12 +7,13 @@ from .config import settings
|
||||
from .jwt_utils import parse_jwt_token
|
||||
|
||||
security = HTTPBearer()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def auth_middleware(request: Request):
|
||||
if not settings.ENABLE_AUTH:
|
||||
# If authentication is disabled, allow the request to proceed
|
||||
logging.warn("Auth disabled")
|
||||
logger.warn("Auth disabled")
|
||||
return {}
|
||||
|
||||
security = HTTPBearer()
|
||||
@@ -24,7 +25,7 @@ async def auth_middleware(request: Request):
|
||||
try:
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
request.state.user = payload
|
||||
logging.info("Token decoded successfully")
|
||||
logger.debug("Token decoded successfully")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
return payload
|
||||
|
||||
@@ -29,6 +29,9 @@ class OAuth2Credentials(_BaseCredentials):
|
||||
scopes: list[str]
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def bearer(self) -> str:
|
||||
return f"Bearer {self.access_token.get_secret_value()}"
|
||||
|
||||
|
||||
class APIKeyCredentials(_BaseCredentials):
|
||||
type: Literal["api_key"] = "api_key"
|
||||
@@ -36,6 +39,9 @@ class APIKeyCredentials(_BaseCredentials):
|
||||
expires_at: Optional[int]
|
||||
"""Unix timestamp (seconds) indicating when the API key expires (if at all)"""
|
||||
|
||||
def bearer(self) -> str:
|
||||
return f"Bearer {self.api_key.get_secret_value()}"
|
||||
|
||||
|
||||
Credentials = Annotated[
|
||||
OAuth2Credentials | APIKeyCredentials,
|
||||
@@ -43,6 +49,9 @@ Credentials = Annotated[
|
||||
]
|
||||
|
||||
|
||||
CredentialsType = Literal["api_key", "oauth2"]
|
||||
|
||||
|
||||
class OAuthState(BaseModel):
|
||||
token: str
|
||||
provider: str
|
||||
|
||||
@@ -11,13 +11,30 @@ REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
ENABLE_AUTH=false
|
||||
ENABLE_CREDIT=false
|
||||
APP_ENV="local"
|
||||
PYRO_HOST=localhost
|
||||
SENTRY_DSN=
|
||||
# This is needed when ENABLE_AUTH is true
|
||||
SUPABASE_JWT_SECRET=our-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
|
||||
## User auth with Supabase is required for any of the 3rd party integrations with auth to work.
|
||||
ENABLE_AUTH=false
|
||||
SUPABASE_URL=
|
||||
SUPABASE_SERVICE_ROLE_KEY=
|
||||
SUPABASE_JWT_SECRET=
|
||||
|
||||
# For local development, you may need to set FRONTEND_BASE_URL for the OAuth flow for integrations to work.
|
||||
# FRONTEND_BASE_URL=http://localhost:3000
|
||||
|
||||
## == INTEGRATION CREDENTIALS == ##
|
||||
# Each set of server side credentials is required for the corresponding 3rd party
|
||||
# integration to work.
|
||||
|
||||
# For the OAuth callback URL, use <your_frontend_url>/auth/integrations/oauth_callback,
|
||||
# e.g. http://localhost:3000/auth/integrations/oauth_callback
|
||||
|
||||
# GitHub OAuth App server credentials - https://github.com/settings/developers
|
||||
GITHUB_CLIENT_ID=
|
||||
GITHUB_CLIENT_SECRET=
|
||||
|
||||
## ===== OPTIONAL API KEYS ===== ##
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import glob
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
@@ -8,17 +7,17 @@ from backend.data.block import Block
|
||||
|
||||
# Dynamically load all modules under backend.blocks
|
||||
AVAILABLE_MODULES = []
|
||||
current_dir = os.path.dirname(__file__)
|
||||
modules = glob.glob(os.path.join(current_dir, "*.py"))
|
||||
current_dir = Path(__file__).parent
|
||||
modules = [
|
||||
Path(f).stem
|
||||
for f in modules
|
||||
if os.path.isfile(f) and f.endswith(".py") and not f.endswith("__init__.py")
|
||||
str(f.relative_to(current_dir))[:-3].replace(os.path.sep, ".")
|
||||
for f in current_dir.rglob("*.py")
|
||||
if f.is_file() and f.name != "__init__.py"
|
||||
]
|
||||
for module in modules:
|
||||
if not re.match("^[a-z_]+$", module):
|
||||
if not re.match("^[a-z_.]+$", module):
|
||||
raise ValueError(
|
||||
f"Block module {module} error: module name must be lowercase, separated by underscores, and contain only alphabet characters"
|
||||
f"Block module {module} error: module name must be lowercase, "
|
||||
"separated by underscores, and contain only alphabet characters"
|
||||
)
|
||||
|
||||
importlib.import_module(f".{module}", package=__name__)
|
||||
|
||||
@@ -57,7 +57,7 @@ class StoreValueBlock(Block):
|
||||
static_output=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "output", input_data.data or input_data.input
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ class PrintToConsoleBlock(Block):
|
||||
test_output=("status", "printed"),
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
print(">>>>> Print: ", input_data.text)
|
||||
yield "status", "printed"
|
||||
|
||||
@@ -118,7 +118,7 @@ class FindInDictionaryBlock(Block):
|
||||
categories={BlockCategory.BASIC},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
obj = input_data.input
|
||||
key = input_data.key
|
||||
|
||||
@@ -200,7 +200,7 @@ class AgentInputBlock(Block):
|
||||
ui_type=BlockUIType.INPUT,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
@@ -283,7 +283,7 @@ class AgentOutputBlock(Block):
|
||||
ui_type=BlockUIType.OUTPUT,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Attempts to format the recorded_value using the fmt_string if provided.
|
||||
If formatting fails or no fmt_string is given, returns the original recorded_value.
|
||||
@@ -343,7 +343,7 @@ class AddToDictionaryBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# If no dictionary is provided, create a new one
|
||||
if input_data.dictionary is None:
|
||||
@@ -414,7 +414,7 @@ class AddToListBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# If no list is provided, create a new one
|
||||
if input_data.list is None:
|
||||
@@ -455,5 +455,5 @@ class NoteBlock(Block):
|
||||
ui_type=BlockUIType.NOTE,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "output", input_data.text
|
||||
|
||||
@@ -31,7 +31,7 @@ class BlockInstallationBlock(Block):
|
||||
disabled=True,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
code = input_data.code
|
||||
|
||||
if search := re.search(r"class (\w+)\(Block\):", code):
|
||||
|
||||
@@ -70,7 +70,7 @@ class ConditionBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
value1 = input_data.value1
|
||||
operator = input_data.operator
|
||||
value2 = input_data.value2
|
||||
|
||||
@@ -40,7 +40,7 @@ class ReadCsvBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
import csv
|
||||
from io import StringIO
|
||||
|
||||
|
||||
@@ -81,14 +81,14 @@ class ReadDiscordMessagesBlock(Block):
|
||||
|
||||
await client.start(token)
|
||||
|
||||
def run(self, input_data: "ReadDiscordMessagesBlock.Input") -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
while True:
|
||||
for output_name, output_value in self.__run(input_data):
|
||||
yield output_name, output_value
|
||||
if not input_data.continuous_read:
|
||||
break
|
||||
|
||||
def __run(self, input_data: "ReadDiscordMessagesBlock.Input") -> BlockOutput:
|
||||
def __run(self, input_data: Input) -> BlockOutput:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
future = self.run_bot(input_data.discord_bot_token.get_secret_value())
|
||||
@@ -187,7 +187,7 @@ class SendDiscordMessageBlock(Block):
|
||||
"""Splits a message into chunks not exceeding the Discord limit."""
|
||||
return [message[i : i + limit] for i in range(0, len(message), limit)]
|
||||
|
||||
def run(self, input_data: "SendDiscordMessageBlock.Input") -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
future = self.send_message(
|
||||
|
||||
@@ -88,7 +88,7 @@ class SendEmailBlock(Block):
|
||||
except Exception as e:
|
||||
return f"Failed to send email: {str(e)}"
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
status = self.send_email(
|
||||
input_data.creds,
|
||||
input_data.to_email,
|
||||
|
||||
54
autogpt_platform/backend/backend/blocks/github/_auth.py
Normal file
54
autogpt_platform/backend/backend/blocks/github/_auth.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Literal
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import (
|
||||
APIKeyCredentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import CredentialsField, CredentialsMetaInput
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
secrets = Secrets()
|
||||
GITHUB_OAUTH_IS_CONFIGURED = bool(
|
||||
secrets.github_client_id and secrets.github_client_secret
|
||||
)
|
||||
|
||||
GithubCredentials = APIKeyCredentials | OAuth2Credentials
|
||||
GithubCredentialsInput = CredentialsMetaInput[
|
||||
Literal["github"],
|
||||
Literal["api_key", "oauth2"] if GITHUB_OAUTH_IS_CONFIGURED else Literal["api_key"],
|
||||
]
|
||||
|
||||
|
||||
def GithubCredentialsField(scope: str) -> GithubCredentialsInput:
|
||||
"""
|
||||
Creates a GitHub credentials input on a block.
|
||||
|
||||
Params:
|
||||
scope: The authorization scope needed for the block to work. ([list of available scopes](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/scopes-for-oauth-apps#available-scopes))
|
||||
""" # noqa
|
||||
return CredentialsField(
|
||||
provider="github",
|
||||
supported_credential_types=(
|
||||
{"api_key", "oauth2"} if GITHUB_OAUTH_IS_CONFIGURED else {"api_key"}
|
||||
),
|
||||
required_scopes={scope},
|
||||
description="The GitHub integration can be used with OAuth, "
|
||||
"or any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="github",
|
||||
api_key=SecretStr("mock-github-api-key"),
|
||||
title="Mock GitHub API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
}
|
||||
683
autogpt_platform/backend/backend/blocks/github/issues.py
Normal file
683
autogpt_platform/backend/backend/blocks/github/issues.py
Normal file
@@ -0,0 +1,683 @@
|
||||
import requests
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GithubCommentBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue or pull request",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
)
|
||||
comment: str = SchemaField(
|
||||
description="Comment to post on the issue or pull request",
|
||||
placeholder="Enter your comment",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
id: int = SchemaField(description="ID of the created comment")
|
||||
url: str = SchemaField(description="URL to the comment on GitHub")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the comment posting failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a8db4d8d-db1c-4a25-a1b0-416a8c33602b",
|
||||
description="This block posts a comment on a specified GitHub issue or pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubCommentBlock.Input,
|
||||
output_schema=GithubCommentBlock.Output,
|
||||
test_input={
|
||||
"issue_url": "https://github.com/owner/repo/issues/1",
|
||||
"comment": "This is a test comment.",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("id", 1337),
|
||||
("url", "https://github.com/owner/repo/issues/1#issuecomment-1337"),
|
||||
],
|
||||
test_mock={
|
||||
"post_comment": lambda *args, **kwargs: (
|
||||
1337,
|
||||
"https://github.com/owner/repo/issues/1#issuecomment-1337",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def post_comment(
|
||||
credentials: GithubCredentials, issue_url: str, body_text: str
|
||||
) -> tuple[int, str]:
|
||||
if "/pull/" in issue_url:
|
||||
api_url = (
|
||||
issue_url.replace("github.com", "api.github.com/repos").replace(
|
||||
"/pull/", "/issues/"
|
||||
)
|
||||
+ "/comments"
|
||||
)
|
||||
else:
|
||||
api_url = (
|
||||
issue_url.replace("github.com", "api.github.com/repos") + "/comments"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
data = {"body": body_text}
|
||||
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
comment = response.json()
|
||||
return comment["id"], comment["html_url"]
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
id, url = self.post_comment(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.comment,
|
||||
)
|
||||
yield "id", id
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to post comment: {str(e)}"
|
||||
|
||||
|
||||
class GithubMakeIssueBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Title of the issue", placeholder="Enter the issue title"
|
||||
)
|
||||
body: str = SchemaField(
|
||||
description="Body of the issue", placeholder="Enter the issue body"
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
number: int = SchemaField(description="Number of the created issue")
|
||||
url: str = SchemaField(description="URL of the created issue")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the issue creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="691dad47-f494-44c3-a1e8-05b7990f2dab",
|
||||
description="This block creates a new issue on a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeIssueBlock.Input,
|
||||
output_schema=GithubMakeIssueBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"title": "Test Issue",
|
||||
"body": "This is a test issue.",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("number", 1),
|
||||
("url", "https://github.com/owner/repo/issues/1"),
|
||||
],
|
||||
test_mock={
|
||||
"create_issue": lambda *args, **kwargs: (
|
||||
1,
|
||||
"https://github.com/owner/repo/issues/1",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_issue(
|
||||
credentials: GithubCredentials, repo_url: str, title: str, body: str
|
||||
) -> tuple[int, str]:
|
||||
api_url = repo_url.replace("github.com", "api.github.com/repos") + "/issues"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
data = {"title": title, "body": body}
|
||||
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
issue = response.json()
|
||||
return issue["number"], issue["html_url"]
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
number, url = self.create_issue(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.title,
|
||||
input_data.body,
|
||||
)
|
||||
yield "number", number
|
||||
yield "url", url
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create issue: {str(e)}"
|
||||
|
||||
|
||||
class GithubReadIssueBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
title: str = SchemaField(description="Title of the issue")
|
||||
body: str = SchemaField(description="Body of the issue")
|
||||
user: str = SchemaField(description="User who created the issue")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the issue failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="6443c75d-032a-4772-9c08-230c707c8acc",
|
||||
description="This block reads the body, title, and user of a specified GitHub issue.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadIssueBlock.Input,
|
||||
output_schema=GithubReadIssueBlock.Output,
|
||||
test_input={
|
||||
"issue_url": "https://github.com/owner/repo/issues/1",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("title", "Title of the issue"),
|
||||
("body", "This is the body of the issue."),
|
||||
("user", "username"),
|
||||
],
|
||||
test_mock={
|
||||
"read_issue": lambda *args, **kwargs: (
|
||||
"Title of the issue",
|
||||
"This is the body of the issue.",
|
||||
"username",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def read_issue(
|
||||
credentials: GithubCredentials, issue_url: str
|
||||
) -> tuple[str, str, str]:
|
||||
api_url = issue_url.replace("github.com", "api.github.com/repos")
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
title = data.get("title", "No title found")
|
||||
body = data.get("body", "No body content found")
|
||||
user = data.get("user", {}).get("login", "No user found")
|
||||
|
||||
return title, body, user
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
title, body, user = self.read_issue(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
)
|
||||
yield "title", title
|
||||
yield "body", body
|
||||
yield "user", user
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to read issue: {str(e)}"
|
||||
|
||||
|
||||
class GithubListIssuesBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class IssueItem(TypedDict):
|
||||
title: str
|
||||
url: str
|
||||
|
||||
issue: IssueItem = SchemaField(
|
||||
title="Issue", description="Issues with their title and URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing issues failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c215bfd7-0e57-4573-8f8c-f7d4963dcd74",
|
||||
description="This block lists all issues for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListIssuesBlock.Input,
|
||||
output_schema=GithubListIssuesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"issue",
|
||||
{
|
||||
"title": "Issue 1",
|
||||
"url": "https://github.com/owner/repo/issues/1",
|
||||
},
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_issues": lambda *args, **kwargs: [
|
||||
{
|
||||
"title": "Issue 1",
|
||||
"url": "https://github.com/owner/repo/issues/1",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_issues(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.IssueItem]:
|
||||
api_url = repo_url.replace("github.com", "api.github.com/repos") + "/issues"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
issues: list[GithubListIssuesBlock.Output.IssueItem] = [
|
||||
{"title": issue["title"], "url": issue["html_url"]} for issue in data
|
||||
]
|
||||
|
||||
return issues
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
issues = self.list_issues(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("issue", issue) for issue in issues)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list issues: {str(e)}"
|
||||
|
||||
|
||||
class GithubAddLabelBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue or pull request",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
)
|
||||
label: str = SchemaField(
|
||||
description="Label to add to the issue or pull request",
|
||||
placeholder="Enter the label",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Status of the label addition operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the label addition failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="98bd6b77-9506-43d5-b669-6b9733c4b1f1",
|
||||
description="This block adds a label to a specified GitHub issue or pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubAddLabelBlock.Input,
|
||||
output_schema=GithubAddLabelBlock.Output,
|
||||
test_input={
|
||||
"issue_url": "https://github.com/owner/repo/issues/1",
|
||||
"label": "bug",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Label added successfully")],
|
||||
test_mock={"add_label": lambda *args, **kwargs: "Label added successfully"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def add_label(credentials: GithubCredentials, issue_url: str, label: str) -> str:
|
||||
# Convert the provided GitHub URL to the API URL
|
||||
if "/pull/" in issue_url:
|
||||
api_url = (
|
||||
issue_url.replace("github.com", "api.github.com/repos").replace(
|
||||
"/pull/", "/issues/"
|
||||
)
|
||||
+ "/labels"
|
||||
)
|
||||
else:
|
||||
api_url = (
|
||||
issue_url.replace("github.com", "api.github.com/repos") + "/labels"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
data = {"labels": [label]}
|
||||
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
return "Label added successfully"
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.add_label(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.label,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to add label: {str(e)}"
|
||||
|
||||
|
||||
class GithubRemoveLabelBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue or pull request",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
)
|
||||
label: str = SchemaField(
|
||||
description="Label to remove from the issue or pull request",
|
||||
placeholder="Enter the label",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Status of the label removal operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the label removal failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="78f050c5-3e3a-48c0-9e5b-ef1ceca5589c",
|
||||
description="This block removes a label from a specified GitHub issue or pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubRemoveLabelBlock.Input,
|
||||
output_schema=GithubRemoveLabelBlock.Output,
|
||||
test_input={
|
||||
"issue_url": "https://github.com/owner/repo/issues/1",
|
||||
"label": "bug",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Label removed successfully")],
|
||||
test_mock={
|
||||
"remove_label": lambda *args, **kwargs: "Label removed successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def remove_label(credentials: GithubCredentials, issue_url: str, label: str) -> str:
|
||||
# Convert the provided GitHub URL to the API URL
|
||||
if "/pull/" in issue_url:
|
||||
api_url = (
|
||||
issue_url.replace("github.com", "api.github.com/repos").replace(
|
||||
"/pull/", "/issues/"
|
||||
)
|
||||
+ f"/labels/{label}"
|
||||
)
|
||||
else:
|
||||
api_url = (
|
||||
issue_url.replace("github.com", "api.github.com/repos")
|
||||
+ f"/labels/{label}"
|
||||
)
|
||||
|
||||
# Log the constructed API URL for debugging
|
||||
print(f"Constructed API URL: {api_url}")
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.delete(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
return "Label removed successfully"
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.remove_label(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.label,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to remove label: {str(e)}"
|
||||
|
||||
|
||||
class GithubAssignIssueBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
)
|
||||
assignee: str = SchemaField(
|
||||
description="Username to assign to the issue",
|
||||
placeholder="Enter the username",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(
|
||||
description="Status of the issue assignment operation"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the issue assignment failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="90507c72-b0ff-413a-886a-23bbbd66f542",
|
||||
description="This block assigns a user to a specified GitHub issue.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubAssignIssueBlock.Input,
|
||||
output_schema=GithubAssignIssueBlock.Output,
|
||||
test_input={
|
||||
"issue_url": "https://github.com/owner/repo/issues/1",
|
||||
"assignee": "username1",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Issue assigned successfully")],
|
||||
test_mock={
|
||||
"assign_issue": lambda *args, **kwargs: "Issue assigned successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def assign_issue(
|
||||
credentials: GithubCredentials,
|
||||
issue_url: str,
|
||||
assignee: str,
|
||||
) -> str:
|
||||
# Extracting repo path and issue number from the issue URL
|
||||
repo_path, issue_number = issue_url.replace("https://github.com/", "").split(
|
||||
"/issues/"
|
||||
)
|
||||
api_url = (
|
||||
f"https://api.github.com/repos/{repo_path}/issues/{issue_number}/assignees"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
data = {"assignees": [assignee]}
|
||||
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
return "Issue assigned successfully"
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.assign_issue(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.assignee,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to assign issue: {str(e)}"
|
||||
|
||||
|
||||
class GithubUnassignIssueBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
issue_url: str = SchemaField(
|
||||
description="URL of the GitHub issue",
|
||||
placeholder="https://github.com/owner/repo/issues/1",
|
||||
)
|
||||
assignee: str = SchemaField(
|
||||
description="Username to unassign from the issue",
|
||||
placeholder="Enter the username",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(
|
||||
description="Status of the issue unassignment operation"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the issue unassignment failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d154002a-38f4-46c2-962d-2488f2b05ece",
|
||||
description="This block unassigns a user from a specified GitHub issue.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUnassignIssueBlock.Input,
|
||||
output_schema=GithubUnassignIssueBlock.Output,
|
||||
test_input={
|
||||
"issue_url": "https://github.com/owner/repo/issues/1",
|
||||
"assignee": "username1",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Issue unassigned successfully")],
|
||||
test_mock={
|
||||
"unassign_issue": lambda *args, **kwargs: "Issue unassigned successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unassign_issue(
|
||||
credentials: GithubCredentials,
|
||||
issue_url: str,
|
||||
assignee: str,
|
||||
) -> str:
|
||||
# Extracting repo path and issue number from the issue URL
|
||||
repo_path, issue_number = issue_url.replace("https://github.com/", "").split(
|
||||
"/issues/"
|
||||
)
|
||||
api_url = (
|
||||
f"https://api.github.com/repos/{repo_path}/issues/{issue_number}/assignees"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
data = {"assignees": [assignee]}
|
||||
|
||||
response = requests.delete(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
return "Issue unassigned successfully"
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.unassign_issue(
|
||||
credentials,
|
||||
input_data.issue_url,
|
||||
input_data.assignee,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to unassign issue: {str(e)}"
|
||||
596
autogpt_platform/backend/backend/blocks/github/pull_requests.py
Normal file
596
autogpt_platform/backend/backend/blocks/github/pull_requests.py
Normal file
@@ -0,0 +1,596 @@
|
||||
import requests
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GithubListPullRequestsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class PRItem(TypedDict):
|
||||
title: str
|
||||
url: str
|
||||
|
||||
pull_request: PRItem = SchemaField(
|
||||
title="Pull Request", description="PRs with their title and URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing issues failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="ffef3c4c-6cd0-48dd-817d-459f975219f4",
|
||||
description="This block lists all pull requests for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListPullRequestsBlock.Input,
|
||||
output_schema=GithubListPullRequestsBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"pull_request",
|
||||
{
|
||||
"title": "Pull request 1",
|
||||
"url": "https://github.com/owner/repo/pull/1",
|
||||
},
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_prs": lambda *args, **kwargs: [
|
||||
{
|
||||
"title": "Pull request 1",
|
||||
"url": "https://github.com/owner/repo/pull/1",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_prs(credentials: GithubCredentials, repo_url: str) -> list[Output.PRItem]:
|
||||
api_url = repo_url.replace("github.com", "api.github.com/repos") + "/pulls"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
pull_requests: list[GithubListPullRequestsBlock.Output.PRItem] = [
|
||||
{"title": pr["title"], "url": pr["html_url"]} for pr in data
|
||||
]
|
||||
|
||||
return pull_requests
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
pull_requests = self.list_prs(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("pull_request", pr) for pr in pull_requests)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list pull requests: {str(e)}"
|
||||
|
||||
|
||||
class GithubMakePullRequestBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
title: str = SchemaField(
|
||||
description="Title of the pull request",
|
||||
placeholder="Enter the pull request title",
|
||||
)
|
||||
body: str = SchemaField(
|
||||
description="Body of the pull request",
|
||||
placeholder="Enter the pull request body",
|
||||
)
|
||||
head: str = SchemaField(
|
||||
description="The name of the branch where your changes are implemented. For cross-repository pull requests in the same network, namespace head with a user like this: username:branch.",
|
||||
placeholder="Enter the head branch",
|
||||
)
|
||||
base: str = SchemaField(
|
||||
description="The name of the branch you want the changes pulled into.",
|
||||
placeholder="Enter the base branch",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
number: int = SchemaField(description="Number of the created pull request")
|
||||
url: str = SchemaField(description="URL of the created pull request")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the pull request creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="dfb987f8-f197-4b2e-bf19-111812afd692",
|
||||
description="This block creates a new pull request on a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakePullRequestBlock.Input,
|
||||
output_schema=GithubMakePullRequestBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"title": "Test Pull Request",
|
||||
"body": "This is a test pull request.",
|
||||
"head": "feature-branch",
|
||||
"base": "main",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("number", 1),
|
||||
("url", "https://github.com/owner/repo/pull/1"),
|
||||
],
|
||||
test_mock={
|
||||
"create_pr": lambda *args, **kwargs: (
|
||||
1,
|
||||
"https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_pr(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
title: str,
|
||||
body: str,
|
||||
head: str,
|
||||
base: str,
|
||||
) -> tuple[int, str]:
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
api_url = f"https://api.github.com/repos/{repo_path}/pulls"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
data = {"title": title, "body": body, "head": head, "base": base}
|
||||
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
pr_data = response.json()
|
||||
return pr_data["number"], pr_data["html_url"]
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
number, url = self.create_pr(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.title,
|
||||
input_data.body,
|
||||
input_data.head,
|
||||
input_data.base,
|
||||
)
|
||||
yield "number", number
|
||||
yield "url", url
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
if http_err.response.status_code == 422:
|
||||
error_details = http_err.response.json()
|
||||
error_message = error_details.get("message", "Unknown error")
|
||||
else:
|
||||
error_message = str(http_err)
|
||||
yield "error", f"Failed to create pull request: {error_message}"
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create pull request: {str(e)}"
|
||||
|
||||
|
||||
class GithubReadPullRequestBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
include_pr_changes: bool = SchemaField(
|
||||
description="Whether to include the changes made in the pull request",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
title: str = SchemaField(description="Title of the pull request")
|
||||
body: str = SchemaField(description="Body of the pull request")
|
||||
author: str = SchemaField(description="User who created the pull request")
|
||||
changes: str = SchemaField(description="Changes made in the pull request")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the pull request failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="bf94b2a4-1a30-4600-a783-a8a44ee31301",
|
||||
description="This block reads the body, title, user, and changes of a specified GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadPullRequestBlock.Input,
|
||||
output_schema=GithubReadPullRequestBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"include_pr_changes": True,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("title", "Title of the pull request"),
|
||||
("body", "This is the body of the pull request."),
|
||||
("author", "username"),
|
||||
("changes", "List of changes made in the pull request."),
|
||||
],
|
||||
test_mock={
|
||||
"read_pr": lambda *args, **kwargs: (
|
||||
"Title of the pull request",
|
||||
"This is the body of the pull request.",
|
||||
"username",
|
||||
),
|
||||
"read_pr_changes": lambda *args, **kwargs: "List of changes made in the pull request.",
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def read_pr(credentials: GithubCredentials, pr_url: str) -> tuple[str, str, str]:
|
||||
api_url = pr_url.replace("github.com", "api.github.com/repos").replace(
|
||||
"/pull/", "/issues/"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
title = data.get("title", "No title found")
|
||||
body = data.get("body", "No body content found")
|
||||
author = data.get("user", {}).get("login", "No user found")
|
||||
|
||||
return title, body, author
|
||||
|
||||
@staticmethod
|
||||
def read_pr_changes(credentials: GithubCredentials, pr_url: str) -> str:
|
||||
api_url = (
|
||||
pr_url.replace("github.com", "api.github.com/repos").replace(
|
||||
"/pull/", "/pulls/"
|
||||
)
|
||||
+ "/files"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
files = response.json()
|
||||
changes = []
|
||||
for file in files:
|
||||
filename = file.get("filename")
|
||||
patch = file.get("patch")
|
||||
if filename and patch:
|
||||
changes.append(f"File: {filename}\n{patch}")
|
||||
|
||||
return "\n\n".join(changes)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
title, body, author = self.read_pr(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
)
|
||||
yield "title", title
|
||||
yield "body", body
|
||||
yield "author", author
|
||||
|
||||
if input_data.include_pr_changes:
|
||||
changes = self.read_pr_changes(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
)
|
||||
yield "changes", changes
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to read pull request: {str(e)}"
|
||||
|
||||
|
||||
class GithubAssignPRReviewerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
reviewer: str = SchemaField(
|
||||
description="Username of the reviewer to assign",
|
||||
placeholder="Enter the reviewer's username",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(
|
||||
description="Status of the reviewer assignment operation"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the reviewer assignment failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c0d22c5e-e688-43e3-ba43-d5faba7927fd",
|
||||
description="This block assigns a reviewer to a specified GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubAssignPRReviewerBlock.Input,
|
||||
output_schema=GithubAssignPRReviewerBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"reviewer": "reviewer_username",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Reviewer assigned successfully")],
|
||||
test_mock={
|
||||
"assign_reviewer": lambda *args, **kwargs: "Reviewer assigned successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def assign_reviewer(
|
||||
credentials: GithubCredentials, pr_url: str, reviewer: str
|
||||
) -> str:
|
||||
# Convert the PR URL to the appropriate API endpoint
|
||||
api_url = (
|
||||
pr_url.replace("github.com", "api.github.com/repos").replace(
|
||||
"/pull/", "/pulls/"
|
||||
)
|
||||
+ "/requested_reviewers"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
data = {"reviewers": [reviewer]}
|
||||
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
return "Reviewer assigned successfully"
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.assign_reviewer(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.reviewer,
|
||||
)
|
||||
yield "status", status
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
if http_err.response.status_code == 422:
|
||||
error_msg = (
|
||||
"Failed to assign reviewer: "
|
||||
f"The reviewer '{input_data.reviewer}' may not have permission "
|
||||
"or the pull request is not in a valid state. "
|
||||
f"Detailed error: {http_err.response.text}"
|
||||
)
|
||||
else:
|
||||
error_msg = f"HTTP error: {http_err} - {http_err.response.text}"
|
||||
yield "error", error_msg
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to assign reviewer: {str(e)}"
|
||||
|
||||
|
||||
class GithubUnassignPRReviewerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
reviewer: str = SchemaField(
|
||||
description="Username of the reviewer to unassign",
|
||||
placeholder="Enter the reviewer's username",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(
|
||||
description="Status of the reviewer unassignment operation"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if the reviewer unassignment failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9637945d-c602-4875-899a-9c22f8fd30de",
|
||||
description="This block unassigns a reviewer from a specified GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubUnassignPRReviewerBlock.Input,
|
||||
output_schema=GithubUnassignPRReviewerBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"reviewer": "reviewer_username",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Reviewer unassigned successfully")],
|
||||
test_mock={
|
||||
"unassign_reviewer": lambda *args, **kwargs: "Reviewer unassigned successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unassign_reviewer(
|
||||
credentials: GithubCredentials, pr_url: str, reviewer: str
|
||||
) -> str:
|
||||
api_url = (
|
||||
pr_url.replace("github.com", "api.github.com/repos").replace(
|
||||
"/pull/", "/pulls/"
|
||||
)
|
||||
+ "/requested_reviewers"
|
||||
)
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
data = {"reviewers": [reviewer]}
|
||||
|
||||
response = requests.delete(api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
return "Reviewer unassigned successfully"
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.unassign_reviewer(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
input_data.reviewer,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to unassign reviewer: {str(e)}"
|
||||
|
||||
|
||||
class GithubListPRReviewersBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
pr_url: str = SchemaField(
|
||||
description="URL of the GitHub pull request",
|
||||
placeholder="https://github.com/owner/repo/pull/1",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class ReviewerItem(TypedDict):
|
||||
username: str
|
||||
url: str
|
||||
|
||||
reviewer: ReviewerItem = SchemaField(
|
||||
title="Reviewer",
|
||||
description="Reviewers with their username and profile URL",
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if listing reviewers failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="2646956e-96d5-4754-a3df-034017e7ed96",
|
||||
description="This block lists all reviewers for a specified GitHub pull request.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListPRReviewersBlock.Input,
|
||||
output_schema=GithubListPRReviewersBlock.Output,
|
||||
test_input={
|
||||
"pr_url": "https://github.com/owner/repo/pull/1",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"reviewer",
|
||||
{
|
||||
"username": "reviewer1",
|
||||
"url": "https://github.com/reviewer1",
|
||||
},
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_reviewers": lambda *args, **kwargs: [
|
||||
{
|
||||
"username": "reviewer1",
|
||||
"url": "https://github.com/reviewer1",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_reviewers(
|
||||
credentials: GithubCredentials, pr_url: str
|
||||
) -> list[Output.ReviewerItem]:
|
||||
api_url = (
|
||||
pr_url.replace("github.com", "api.github.com/repos").replace(
|
||||
"/pull/", "/pulls/"
|
||||
)
|
||||
+ "/requested_reviewers"
|
||||
)
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
reviewers: list[GithubListPRReviewersBlock.Output.ReviewerItem] = [
|
||||
{"username": reviewer["login"], "url": reviewer["html_url"]}
|
||||
for reviewer in data.get("users", [])
|
||||
]
|
||||
|
||||
return reviewers
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
reviewers = self.list_reviewers(
|
||||
credentials,
|
||||
input_data.pr_url,
|
||||
)
|
||||
yield from (("reviewer", reviewer) for reviewer in reviewers)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list reviewers: {str(e)}"
|
||||
786
autogpt_platform/backend/backend/blocks/github/repo.py
Normal file
786
autogpt_platform/backend/backend/blocks/github/repo.py
Normal file
@@ -0,0 +1,786 @@
|
||||
import base64
|
||||
|
||||
import requests
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
from ._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
GithubCredentials,
|
||||
GithubCredentialsField,
|
||||
GithubCredentialsInput,
|
||||
)
|
||||
|
||||
|
||||
class GithubListTagsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class TagItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
tag: TagItem = SchemaField(
|
||||
title="Tag", description="Tags with their name and file tree browser URL"
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing tags failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="358924e7-9a11-4d1a-a0f2-13c67fe59e2e",
|
||||
description="This block lists all tags for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListTagsBlock.Input,
|
||||
output_schema=GithubListTagsBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"tag",
|
||||
{
|
||||
"name": "v1.0.0",
|
||||
"url": "https://github.com/owner/repo/tree/v1.0.0",
|
||||
},
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_tags": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "v1.0.0",
|
||||
"url": "https://github.com/owner/repo/tree/v1.0.0",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_tags(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.TagItem]:
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
api_url = f"https://api.github.com/repos/{repo_path}/tags"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
tags: list[GithubListTagsBlock.Output.TagItem] = [
|
||||
{
|
||||
"name": tag["name"],
|
||||
"url": f"https://github.com/{repo_path}/tree/{tag['name']}",
|
||||
}
|
||||
for tag in data
|
||||
]
|
||||
|
||||
return tags
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
tags = self.list_tags(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("tag", tag) for tag in tags)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list tags: {str(e)}"
|
||||
|
||||
|
||||
class GithubListBranchesBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class BranchItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
branch: BranchItem = SchemaField(
|
||||
title="Branch",
|
||||
description="Branches with their name and file tree browser URL",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing branches failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="74243e49-2bec-4916-8bf4-db43d44aead5",
|
||||
description="This block lists all branches for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListBranchesBlock.Input,
|
||||
output_schema=GithubListBranchesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"branch",
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
},
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_branches": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "main",
|
||||
"url": "https://github.com/owner/repo/tree/main",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_branches(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.BranchItem]:
|
||||
api_url = repo_url.replace("github.com", "api.github.com/repos") + "/branches"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
branches: list[GithubListBranchesBlock.Output.BranchItem] = [
|
||||
{"name": branch["name"], "url": branch["commit"]["url"]} for branch in data
|
||||
]
|
||||
|
||||
return branches
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
branches = self.list_branches(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("branch", branch) for branch in branches)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list branches: {str(e)}"
|
||||
|
||||
|
||||
class GithubListDiscussionsBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
num_discussions: int = SchemaField(
|
||||
description="Number of discussions to fetch", default=5
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class DiscussionItem(TypedDict):
|
||||
title: str
|
||||
url: str
|
||||
|
||||
discussion: DiscussionItem = SchemaField(
|
||||
title="Discussion", description="Discussions with their title and URL"
|
||||
)
|
||||
error: str = SchemaField(
|
||||
description="Error message if listing discussions failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3ef1a419-3d76-4e07-b761-de9dad4d51d7",
|
||||
description="This block lists recent discussions for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListDiscussionsBlock.Input,
|
||||
output_schema=GithubListDiscussionsBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"num_discussions": 3,
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"discussion",
|
||||
{
|
||||
"title": "Discussion 1",
|
||||
"url": "https://github.com/owner/repo/discussions/1",
|
||||
},
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_discussions": lambda *args, **kwargs: [
|
||||
{
|
||||
"title": "Discussion 1",
|
||||
"url": "https://github.com/owner/repo/discussions/1",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_discussions(
|
||||
credentials: GithubCredentials, repo_url: str, num_discussions: int
|
||||
) -> list[Output.DiscussionItem]:
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
owner, repo = repo_path.split("/")
|
||||
query = """
|
||||
query($owner: String!, $repo: String!, $num: Int!) {
|
||||
repository(owner: $owner, name: $repo) {
|
||||
discussions(first: $num) {
|
||||
nodes {
|
||||
title
|
||||
url
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
variables = {"owner": owner, "repo": repo, "num": num_discussions}
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
"https://api.github.com/graphql",
|
||||
json={"query": query, "variables": variables},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
discussions: list[GithubListDiscussionsBlock.Output.DiscussionItem] = [
|
||||
{"title": discussion["title"], "url": discussion["url"]}
|
||||
for discussion in data["data"]["repository"]["discussions"]["nodes"]
|
||||
]
|
||||
|
||||
return discussions
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
discussions = self.list_discussions(
|
||||
credentials, input_data.repo_url, input_data.num_discussions
|
||||
)
|
||||
yield from (("discussion", discussion) for discussion in discussions)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list discussions: {str(e)}"
|
||||
|
||||
|
||||
class GithubListReleasesBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class ReleaseItem(TypedDict):
|
||||
name: str
|
||||
url: str
|
||||
|
||||
release: ReleaseItem = SchemaField(
|
||||
title="Release",
|
||||
description="Releases with their name and file tree browser URL",
|
||||
)
|
||||
error: str = SchemaField(description="Error message if listing releases failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3460367a-6ba7-4645-8ce6-47b05d040b92",
|
||||
description="This block lists all releases for a specified GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubListReleasesBlock.Input,
|
||||
output_schema=GithubListReleasesBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"release",
|
||||
{
|
||||
"name": "v1.0.0",
|
||||
"url": "https://github.com/owner/repo/releases/tag/v1.0.0",
|
||||
},
|
||||
)
|
||||
],
|
||||
test_mock={
|
||||
"list_releases": lambda *args, **kwargs: [
|
||||
{
|
||||
"name": "v1.0.0",
|
||||
"url": "https://github.com/owner/repo/releases/tag/v1.0.0",
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_releases(
|
||||
credentials: GithubCredentials, repo_url: str
|
||||
) -> list[Output.ReleaseItem]:
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
api_url = f"https://api.github.com/repos/{repo_path}/releases"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
releases: list[GithubListReleasesBlock.Output.ReleaseItem] = [
|
||||
{"name": release["name"], "url": release["html_url"]} for release in data
|
||||
]
|
||||
|
||||
return releases
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
releases = self.list_releases(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
)
|
||||
yield from (("release", release) for release in releases)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to list releases: {str(e)}"
|
||||
|
||||
|
||||
class GithubReadFileBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
file_path: str = SchemaField(
|
||||
description="Path to the file in the repository",
|
||||
placeholder="path/to/file",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch to read from",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
text_content: str = SchemaField(
|
||||
description="Content of the file (decoded as UTF-8 text)"
|
||||
)
|
||||
raw_content: str = SchemaField(
|
||||
description="Raw base64-encoded content of the file"
|
||||
)
|
||||
size: int = SchemaField(description="The size of the file (in bytes)")
|
||||
error: str = SchemaField(description="Error message if the file reading failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="87ce6c27-5752-4bbc-8e26-6da40a3dcfd3",
|
||||
description="This block reads the content of a specified file from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFileBlock.Input,
|
||||
output_schema=GithubReadFileBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"file_path": "path/to/file",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("raw_content", "RmlsZSBjb250ZW50"),
|
||||
("text_content", "File content"),
|
||||
("size", 13),
|
||||
],
|
||||
test_mock={"read_file": lambda *args, **kwargs: ("RmlsZSBjb250ZW50", 13)},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def read_file(
|
||||
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
|
||||
) -> tuple[str, int]:
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
api_url = f"https://api.github.com/repos/{repo_path}/contents/{file_path}?ref={branch}"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
content = response.json()
|
||||
|
||||
if isinstance(content, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (file := next((f for f in content if f["type"] == "file"), None)):
|
||||
raise TypeError("Not a file")
|
||||
content = file
|
||||
|
||||
if content["type"] != "file":
|
||||
raise TypeError("Not a file")
|
||||
|
||||
return content["content"], content["size"]
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
raw_content, size = self.read_file(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.file_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
yield "raw_content", raw_content
|
||||
yield "text_content", base64.b64decode(raw_content).decode("utf-8")
|
||||
yield "size", size
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to read file: {str(e)}"
|
||||
|
||||
|
||||
class GithubReadFolderBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
folder_path: str = SchemaField(
|
||||
description="Path to the folder in the repository",
|
||||
placeholder="path/to/folder",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Branch name to read from (defaults to master)",
|
||||
placeholder="branch_name",
|
||||
default="master",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
class DirEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
|
||||
class FileEntry(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
size: int
|
||||
|
||||
file: FileEntry = SchemaField(description="Files in the folder")
|
||||
dir: DirEntry = SchemaField(description="Directories in the folder")
|
||||
error: str = SchemaField(
|
||||
description="Error message if reading the folder failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="1355f863-2db3-4d75-9fba-f91e8a8ca400",
|
||||
description="This block reads the content of a specified folder from a GitHub repository.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubReadFolderBlock.Input,
|
||||
output_schema=GithubReadFolderBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"folder_path": "path/to/folder",
|
||||
"branch": "master",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
(
|
||||
"file",
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
},
|
||||
),
|
||||
("dir", {"name": "dir2", "path": "path/to/folder/dir2"}),
|
||||
],
|
||||
test_mock={
|
||||
"read_folder": lambda *args, **kwargs: (
|
||||
[
|
||||
{
|
||||
"name": "file1.txt",
|
||||
"path": "path/to/folder/file1.txt",
|
||||
"size": 1337,
|
||||
}
|
||||
],
|
||||
[{"name": "dir2", "path": "path/to/folder/dir2"}],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def read_folder(
|
||||
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
|
||||
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
api_url = f"https://api.github.com/repos/{repo_path}/contents/{folder_path}?ref={branch}"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
content = response.json()
|
||||
|
||||
if isinstance(content, list):
|
||||
# Multiple entries of different types exist at this path
|
||||
if not (dir := next((d for d in content if d["type"] == "dir"), None)):
|
||||
raise TypeError("Not a folder")
|
||||
content = dir
|
||||
|
||||
if content["type"] != "dir":
|
||||
raise TypeError("Not a folder")
|
||||
|
||||
return (
|
||||
[
|
||||
GithubReadFolderBlock.Output.FileEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
size=entry["size"],
|
||||
)
|
||||
for entry in content["entries"]
|
||||
if entry["type"] == "file"
|
||||
],
|
||||
[
|
||||
GithubReadFolderBlock.Output.DirEntry(
|
||||
name=entry["name"],
|
||||
path=entry["path"],
|
||||
)
|
||||
for entry in content["entries"]
|
||||
if entry["type"] == "dir"
|
||||
],
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
files, dirs = self.read_folder(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.folder_path.lstrip("/"),
|
||||
input_data.branch,
|
||||
)
|
||||
yield from (("file", file) for file in files)
|
||||
yield from (("dir", dir) for dir in dirs)
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to read folder: {str(e)}"
|
||||
|
||||
|
||||
class GithubMakeBranchBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
new_branch: str = SchemaField(
|
||||
description="Name of the new branch",
|
||||
placeholder="new_branch_name",
|
||||
)
|
||||
source_branch: str = SchemaField(
|
||||
description="Name of the source branch",
|
||||
placeholder="source_branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Status of the branch creation operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch creation failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="944cc076-95e7-4d1b-b6b6-b15d8ee5448d",
|
||||
description="This block creates a new branch from a specified source branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubMakeBranchBlock.Input,
|
||||
output_schema=GithubMakeBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"new_branch": "new_branch_name",
|
||||
"source_branch": "source_branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch created successfully")],
|
||||
test_mock={
|
||||
"create_branch": lambda *args, **kwargs: "Branch created successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_branch(
|
||||
credentials: GithubCredentials,
|
||||
repo_url: str,
|
||||
new_branch: str,
|
||||
source_branch: str,
|
||||
) -> str:
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
ref_api_url = (
|
||||
f"https://api.github.com/repos/{repo_path}/git/refs/heads/{source_branch}"
|
||||
)
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.get(ref_api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
sha = response.json()["object"]["sha"]
|
||||
|
||||
create_branch_api_url = f"https://api.github.com/repos/{repo_path}/git/refs"
|
||||
data = {"ref": f"refs/heads/{new_branch}", "sha": sha}
|
||||
|
||||
response = requests.post(create_branch_api_url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
return "Branch created successfully"
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.create_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.new_branch,
|
||||
input_data.source_branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to create branch: {str(e)}"
|
||||
|
||||
|
||||
class GithubDeleteBranchBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
|
||||
repo_url: str = SchemaField(
|
||||
description="URL of the GitHub repository",
|
||||
placeholder="https://github.com/owner/repo",
|
||||
)
|
||||
branch: str = SchemaField(
|
||||
description="Name of the branch to delete",
|
||||
placeholder="branch_name",
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
status: str = SchemaField(description="Status of the branch deletion operation")
|
||||
error: str = SchemaField(
|
||||
description="Error message if the branch deletion failed"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="0d4130f7-e0ab-4d55-adc3-0a40225e80f4",
|
||||
description="This block deletes a specified branch.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=GithubDeleteBranchBlock.Input,
|
||||
output_schema=GithubDeleteBranchBlock.Output,
|
||||
test_input={
|
||||
"repo_url": "https://github.com/owner/repo",
|
||||
"branch": "branch_name",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("status", "Branch deleted successfully")],
|
||||
test_mock={
|
||||
"delete_branch": lambda *args, **kwargs: "Branch deleted successfully"
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete_branch(
|
||||
credentials: GithubCredentials, repo_url: str, branch: str
|
||||
) -> str:
|
||||
repo_path = repo_url.replace("https://github.com/", "")
|
||||
api_url = f"https://api.github.com/repos/{repo_path}/git/refs/heads/{branch}"
|
||||
headers = {
|
||||
"Authorization": credentials.bearer(),
|
||||
"Accept": "application/vnd.github.v3+json",
|
||||
}
|
||||
|
||||
response = requests.delete(api_url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
return "Branch deleted successfully"
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: GithubCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
status = self.delete_branch(
|
||||
credentials,
|
||||
input_data.repo_url,
|
||||
input_data.branch,
|
||||
)
|
||||
yield "status", status
|
||||
except Exception as e:
|
||||
yield "error", f"Failed to delete branch: {str(e)}"
|
||||
@@ -37,7 +37,7 @@ class SendWebRequestBlock(Block):
|
||||
output_schema=SendWebRequestBlock.Output,
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if isinstance(input_data.body, str):
|
||||
input_data.body = json.loads(input_data.body)
|
||||
|
||||
|
||||
@@ -31,6 +31,6 @@ class ListIteratorBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
for index, item in enumerate(input_data.items):
|
||||
yield "item", (index, item)
|
||||
|
||||
@@ -203,7 +203,7 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
prompt = []
|
||||
|
||||
def trim_prompt(s: str) -> str:
|
||||
@@ -341,7 +341,7 @@ class AITextGeneratorBlock(Block):
|
||||
raise RuntimeError(output_data)
|
||||
raise ValueError("Failed to get a response from the LLM.")
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
object_input_data = AIStructuredResponseGeneratorBlock.Input(
|
||||
**{attr: getattr(input_data, attr) for attr in input_data.model_fields},
|
||||
@@ -383,7 +383,7 @@ class AITextSummarizerBlock(Block):
|
||||
},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
for output in self._run(input_data):
|
||||
yield output
|
||||
@@ -582,7 +582,7 @@ class AIConversationBlock(Block):
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
api_key = (
|
||||
input_data.api_key.get_secret_value()
|
||||
|
||||
@@ -51,7 +51,7 @@ class CalculatorBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
operation = input_data.operation
|
||||
a = input_data.a
|
||||
b = input_data.b
|
||||
@@ -105,7 +105,7 @@ class CountItemsBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
collection = input_data.collection
|
||||
|
||||
try:
|
||||
|
||||
@@ -136,7 +136,7 @@ class PublishToMediumBlock(Block):
|
||||
|
||||
return response.json()
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
response = self.create_post(
|
||||
input_data.api_key.get_secret_value(),
|
||||
|
||||
@@ -116,7 +116,7 @@ class GetRedditPostsBlock(Block):
|
||||
subreddit = client.subreddit(input_data.subreddit)
|
||||
return subreddit.new(limit=input_data.post_limit)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
current_time = datetime.now(tz=timezone.utc)
|
||||
for post in self.get_posts(input_data):
|
||||
if input_data.last_minutes:
|
||||
@@ -167,5 +167,5 @@ class PostRedditCommentBlock(Block):
|
||||
comment = submission.reply(comment.comment)
|
||||
return comment.id # type: ignore
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
yield "comment_id", self.reply_post(input_data.creds, input_data.data)
|
||||
|
||||
@@ -86,7 +86,7 @@ class ReadRSSFeedBlock(Block):
|
||||
def parse_feed(url: str) -> dict[str, Any]:
|
||||
return feedparser.parse(url) # type: ignore
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
keep_going = True
|
||||
start_time = datetime.now(timezone.utc) - timedelta(
|
||||
minutes=input_data.time_period
|
||||
|
||||
@@ -93,7 +93,7 @@ class DataSamplingBlock(Block):
|
||||
)
|
||||
self.accumulated_data = []
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
if input_data.accumulate:
|
||||
if isinstance(input_data.data, dict):
|
||||
self.accumulated_data.append(input_data.data)
|
||||
|
||||
@@ -35,7 +35,7 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
test_mock={"get_request": lambda url, json: {"extract": "summary content"}},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
topic = input_data.topic
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||
@@ -72,7 +72,7 @@ class SearchTheWebBlock(Block, GetRequest):
|
||||
test_mock={"get_request": lambda url, json: "search content"},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# Encode the search query
|
||||
encoded_query = quote(input_data.query)
|
||||
@@ -113,7 +113,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
||||
test_mock={"get_request": lambda url, json: "scraped content"},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# Prepend the Jina-ai Reader URL to the input URL
|
||||
jina_url = f"https://r.jina.ai/{input_data.url}"
|
||||
@@ -166,7 +166,7 @@ class GetWeatherInformationBlock(Block, GetRequest):
|
||||
},
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
units = "metric" if input_data.use_celsius else "imperial"
|
||||
api_key = input_data.api_key.get_secret_value()
|
||||
|
||||
@@ -105,7 +105,7 @@ class CreateTalkingAvatarVideoBlock(Block):
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
# Create the clip
|
||||
payload = {
|
||||
|
||||
@@ -45,7 +45,7 @@ class MatchTextPatternBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
output = input_data.data or input_data.text
|
||||
flags = 0
|
||||
if not input_data.case_sensitive:
|
||||
@@ -97,7 +97,7 @@ class ExtractTextInformationBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
flags = 0
|
||||
if not input_data.case_sensitive:
|
||||
flags = flags | re.IGNORECASE
|
||||
@@ -147,7 +147,7 @@ class FillTextTemplateBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
# For python.format compatibility: replace all {...} with {{..}}.
|
||||
# But avoid replacing {{...}} to {{{...}}}.
|
||||
fmt = re.sub(r"(?<!{){[ a-zA-Z0-9_]+}", r"{\g<0>}", input_data.format)
|
||||
@@ -180,6 +180,6 @@ class CombineTextsBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
combined_text = input_data.delimiter.join(input_data.input)
|
||||
yield "output", combined_text
|
||||
|
||||
@@ -27,7 +27,7 @@ class GetCurrentTimeBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
current_time = time.strftime("%H:%M:%S")
|
||||
yield "time", current_time
|
||||
|
||||
@@ -59,7 +59,7 @@ class GetCurrentDateBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
offset = int(input_data.offset)
|
||||
except ValueError:
|
||||
@@ -96,7 +96,7 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
current_date_time = time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
yield "date_time", current_date_time
|
||||
|
||||
@@ -129,7 +129,7 @@ class CountdownTimerBlock(Block):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
seconds = int(input_data.seconds)
|
||||
minutes = int(input_data.minutes)
|
||||
hours = int(input_data.hours)
|
||||
|
||||
@@ -62,7 +62,7 @@ class TranscribeYouTubeVideoBlock(Block):
|
||||
def get_transcript(video_id: str):
|
||||
return YouTubeTranscriptApi.get_transcript(video_id)
|
||||
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
video_id = self.extract_video_id(input_data.youtube_url)
|
||||
yield "video_id", video_id
|
||||
|
||||
@@ -1,15 +1,28 @@
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, ClassVar, Generator, Generic, Type, TypeVar, cast
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Generator,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
import jsonref
|
||||
import jsonschema
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import Credentials
|
||||
from prisma.models import AgentBlock
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.model import ContributorDetails
|
||||
from backend.util import json
|
||||
|
||||
from .model import CREDENTIALS_FIELD_NAME, ContributorDetails, CredentialsMetaInput
|
||||
|
||||
BlockData = tuple[str, Any] # Input & Output data should be a tuple of (name, data).
|
||||
BlockInput = dict[str, Any] # Input: 1 input pin consumes 1 data.
|
||||
BlockOutput = Generator[BlockData, None, None] # Output: 1 output pin produces n data.
|
||||
@@ -36,6 +49,7 @@ class BlockCategory(Enum):
|
||||
INPUT = "Block that interacts with input of the graph."
|
||||
OUTPUT = "Block that interacts with output of the graph."
|
||||
LOGIC = "Programming logic to control the flow of your agent"
|
||||
DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
|
||||
|
||||
def dict(self) -> dict[str, str]:
|
||||
return {"category": self.name, "description": self.value}
|
||||
@@ -49,7 +63,7 @@ class BlockSchema(BaseModel):
|
||||
if cls.cached_jsonschema:
|
||||
return cls.cached_jsonschema
|
||||
|
||||
model = jsonref.replace_refs(cls.model_json_schema())
|
||||
model = jsonref.replace_refs(cls.model_json_schema(), merge_props=True)
|
||||
|
||||
def ref_to_dict(obj):
|
||||
if isinstance(obj, dict):
|
||||
@@ -122,6 +136,46 @@ class BlockSchema(BaseModel):
|
||||
if field_info.is_required()
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""Validates the schema definition. Rules:
|
||||
- Only one `CredentialsMetaInput` field may be present.
|
||||
- This field MUST be called `credentials`.
|
||||
- A field that is called `credentials` MUST be a `CredentialsMetaInput`.
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
credentials_fields = [
|
||||
field_name
|
||||
for field_name, info in cls.model_fields.items()
|
||||
if (
|
||||
inspect.isclass(info.annotation)
|
||||
and issubclass(
|
||||
get_origin(info.annotation) or info.annotation,
|
||||
CredentialsMetaInput,
|
||||
)
|
||||
)
|
||||
]
|
||||
if len(credentials_fields) > 1:
|
||||
raise ValueError(
|
||||
f"{cls.__qualname__} can only have one CredentialsMetaInput field"
|
||||
)
|
||||
elif (
|
||||
len(credentials_fields) == 1
|
||||
and credentials_fields[0] != CREDENTIALS_FIELD_NAME
|
||||
):
|
||||
raise ValueError(
|
||||
f"CredentialsMetaInput field on {cls.__qualname__} "
|
||||
"must be named 'credentials'"
|
||||
)
|
||||
elif (
|
||||
len(credentials_fields) == 0
|
||||
and CREDENTIALS_FIELD_NAME in cls.model_fields.keys()
|
||||
):
|
||||
raise TypeError(
|
||||
f"Field 'credentials' on {cls.__qualname__} "
|
||||
f"must be of type {CredentialsMetaInput.__name__}"
|
||||
)
|
||||
|
||||
|
||||
BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
|
||||
BlockSchemaOutputType = TypeVar("BlockSchemaOutputType", bound=BlockSchema)
|
||||
@@ -143,6 +197,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
test_input: BlockInput | list[BlockInput] | None = None,
|
||||
test_output: BlockData | list[BlockData] | None = None,
|
||||
test_mock: dict[str, Any] | None = None,
|
||||
test_credentials: Optional[Credentials] = None,
|
||||
disabled: bool = False,
|
||||
static_output: bool = False,
|
||||
ui_type: BlockUIType = BlockUIType.STANDARD,
|
||||
@@ -170,6 +225,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.test_input = test_input
|
||||
self.test_output = test_output
|
||||
self.test_mock = test_mock
|
||||
self.test_credentials = test_credentials
|
||||
self.description = description
|
||||
self.categories = categories or set()
|
||||
self.contributors = contributors or set()
|
||||
@@ -178,7 +234,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.ui_type = ui_type
|
||||
|
||||
@abstractmethod
|
||||
def run(self, input_data: BlockSchemaInputType) -> BlockOutput:
|
||||
def run(self, input_data: BlockSchemaInputType, **kwargs) -> BlockOutput:
|
||||
"""
|
||||
Run the block with the given input data.
|
||||
Args:
|
||||
@@ -209,13 +265,15 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
"uiType": self.ui_type.value,
|
||||
}
|
||||
|
||||
def execute(self, input_data: BlockInput) -> BlockOutput:
|
||||
def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
|
||||
if error := self.input_schema.validate_data(input_data):
|
||||
raise ValueError(
|
||||
f"Unable to execute block with invalid input data: {error}"
|
||||
)
|
||||
|
||||
for output_name, output_data in self.run(self.input_schema(**input_data)):
|
||||
for output_name, output_data in self.run(
|
||||
self.input_schema(**input_data), **kwargs
|
||||
):
|
||||
if error := self.output_schema.validate_field(output_name, output_data):
|
||||
raise ValueError(f"Block produced an invalid output data: {error}")
|
||||
yield output_name, output_data
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime, timezone
|
||||
from multiprocessing import Manager
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import Credentials
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
from prisma.models import (
|
||||
AgentGraphExecution,
|
||||
@@ -25,6 +26,7 @@ class GraphExecution(BaseModel):
|
||||
graph_exec_id: str
|
||||
graph_id: str
|
||||
start_node_execs: list["NodeExecution"]
|
||||
node_input_credentials: dict[str, Credentials] # dict[node_id, Credentials]
|
||||
|
||||
|
||||
class NodeExecution(BaseModel):
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, ClassVar, Optional, TypeVar
|
||||
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import CredentialsType
|
||||
from pydantic import BaseModel, Field, GetCoreSchemaHandler
|
||||
from pydantic_core import (
|
||||
CoreSchema,
|
||||
@@ -136,5 +137,50 @@ def SchemaField(
|
||||
)
|
||||
|
||||
|
||||
CP = TypeVar("CP", bound=str)
|
||||
CT = TypeVar("CT", bound=CredentialsType)
|
||||
|
||||
|
||||
CREDENTIALS_FIELD_NAME = "credentials"
|
||||
|
||||
|
||||
class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
||||
id: str
|
||||
title: Optional[str] = None
|
||||
provider: CP
|
||||
type: CT
|
||||
|
||||
|
||||
def CredentialsField(
|
||||
provider: CP,
|
||||
supported_credential_types: set[CT],
|
||||
required_scopes: set[str] = set(),
|
||||
*,
|
||||
title: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> CredentialsMetaInput[CP, CT]:
|
||||
"""
|
||||
`CredentialsField` must and can only be used on fields named `credentials`.
|
||||
This is enforced by the `BlockSchema` base class.
|
||||
"""
|
||||
json_extra = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"credentials_provider": provider,
|
||||
"credentials_scopes": list(required_scopes) or None, # omit if empty
|
||||
"credentials_types": list(supported_credential_types),
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
return Field(
|
||||
title=title,
|
||||
description=description,
|
||||
json_schema_extra=json_extra,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class ContributorDetails(BaseModel):
|
||||
name: str = Field(title="Name", description="The name of the contributor.")
|
||||
|
||||
@@ -9,7 +9,10 @@ import threading
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing.pool import AsyncResult, Pool
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar, cast
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import Credentials
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.server.rest_api import AgentServer
|
||||
@@ -37,6 +40,7 @@ from backend.data.execution import (
|
||||
upsert_execution_output,
|
||||
)
|
||||
from backend.data.graph import Graph, Link, Node, get_graph, get_node
|
||||
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
|
||||
from backend.util import json
|
||||
from backend.util.decorator import error_logged, time_measured
|
||||
from backend.util.logging import configure_logging
|
||||
@@ -100,6 +104,7 @@ def execute_node(
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
api_client: "AgentServer",
|
||||
data: NodeExecution,
|
||||
input_credentials: Credentials | None = None,
|
||||
execution_stats: dict[str, Any] | None = None,
|
||||
) -> ExecutionStream:
|
||||
"""
|
||||
@@ -159,13 +164,19 @@ def execute_node(
|
||||
update_execution(ExecutionStatus.RUNNING)
|
||||
user_credit = get_user_credit_model()
|
||||
|
||||
extra_exec_kwargs = {}
|
||||
if input_credentials:
|
||||
extra_exec_kwargs["credentials"] = input_credentials
|
||||
|
||||
output_size = 0
|
||||
try:
|
||||
credit = wait(user_credit.get_or_refill_credit(user_id))
|
||||
if credit < 0:
|
||||
raise ValueError(f"Insufficient credit: {credit}")
|
||||
|
||||
for output_name, output_data in node_block.execute(input_data):
|
||||
for output_name, output_data in node_block.execute(
|
||||
input_data, **extra_exec_kwargs
|
||||
):
|
||||
output_size += len(json.dumps(output_data))
|
||||
log_metadata.info("Node produced output", output_name=output_data)
|
||||
wait(upsert_execution_output(node_exec_id, output_name, output_data))
|
||||
@@ -460,7 +471,10 @@ class Executor:
|
||||
@classmethod
|
||||
@error_logged
|
||||
def on_node_execution(
|
||||
cls, q: ExecutionQueue[NodeExecution], node_exec: NodeExecution
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
node_exec: NodeExecution,
|
||||
input_credentials: Credentials | None,
|
||||
):
|
||||
log_metadata = LogMetadata(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -473,7 +487,7 @@ class Executor:
|
||||
|
||||
execution_stats = {}
|
||||
timing_info, _ = cls._on_node_execution(
|
||||
q, node_exec, log_metadata, execution_stats
|
||||
q, node_exec, input_credentials, log_metadata, execution_stats
|
||||
)
|
||||
execution_stats["walltime"] = timing_info.wall_time
|
||||
execution_stats["cputime"] = timing_info.cpu_time
|
||||
@@ -488,13 +502,14 @@ class Executor:
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
node_exec: NodeExecution,
|
||||
input_credentials: Credentials | None,
|
||||
log_metadata: LogMetadata,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
try:
|
||||
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
|
||||
for execution in execute_node(
|
||||
cls.loop, cls.agent_server_client, node_exec, stats
|
||||
cls.loop, cls.agent_server_client, node_exec, input_credentials, stats
|
||||
):
|
||||
q.add(execution)
|
||||
log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
|
||||
@@ -624,7 +639,11 @@ class Executor:
|
||||
)
|
||||
running_executions[exec_data.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(queue, exec_data),
|
||||
(
|
||||
queue,
|
||||
exec_data,
|
||||
graph_exec.node_input_credentials.get(exec_data.node_id),
|
||||
),
|
||||
callback=make_exec_callback(exec_data),
|
||||
)
|
||||
|
||||
@@ -660,11 +679,17 @@ class ExecutionManager(AppService):
|
||||
def __init__(self):
|
||||
super().__init__(port=Config().execution_manager_port)
|
||||
self.use_db = True
|
||||
self.use_supabase = True
|
||||
self.pool_size = Config().num_graph_workers
|
||||
self.queue = ExecutionQueue[GraphExecution]()
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
|
||||
def run_service(self):
|
||||
from autogpt_libs.supabase_integration_credentials_store import (
|
||||
SupabaseIntegrationCredentialsStore,
|
||||
)
|
||||
|
||||
self.credentials_store = SupabaseIntegrationCredentialsStore(self.supabase)
|
||||
self.executor = ProcessPoolExecutor(
|
||||
max_workers=self.pool_size,
|
||||
initializer=Executor.on_graph_executor_start,
|
||||
@@ -705,7 +730,10 @@ class ExecutionManager(AppService):
|
||||
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
|
||||
if not graph:
|
||||
raise Exception(f"Graph #{graph_id} not found.")
|
||||
|
||||
graph.validate_graph(for_run=True)
|
||||
node_input_credentials = self._get_node_input_credentials(graph, user_id)
|
||||
|
||||
nodes_input = []
|
||||
for node in graph.starting_nodes:
|
||||
input_data = {}
|
||||
@@ -753,6 +781,7 @@ class ExecutionManager(AppService):
|
||||
graph_id=graph_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
start_node_execs=starting_node_execs,
|
||||
node_input_credentials=node_input_credentials,
|
||||
)
|
||||
self.queue.add(graph_exec)
|
||||
|
||||
@@ -799,6 +828,58 @@ class ExecutionManager(AppService):
|
||||
)
|
||||
self.agent_server_client.send_execution_update(exec_update.model_dump())
|
||||
|
||||
def _get_node_input_credentials(
|
||||
self, graph: Graph, user_id: str
|
||||
) -> dict[str, Credentials]:
|
||||
"""Gets all credentials for all nodes of the graph"""
|
||||
|
||||
node_credentials: dict[str, Credentials] = {}
|
||||
|
||||
for node in graph.nodes:
|
||||
block = get_block(node.block_id)
|
||||
if not block:
|
||||
raise ValueError(f"Unknown block {node.block_id} for node #{node.id}")
|
||||
|
||||
# Find any fields of type CredentialsMetaInput
|
||||
model_fields = cast(type[BaseModel], block.input_schema).model_fields
|
||||
if CREDENTIALS_FIELD_NAME not in model_fields:
|
||||
continue
|
||||
|
||||
field = model_fields[CREDENTIALS_FIELD_NAME]
|
||||
|
||||
# The BlockSchema class enforces that a `credentials` field is always a
|
||||
# `CredentialsMetaInput`, so we can safely assume this here.
|
||||
credentials_meta_type = cast(CredentialsMetaInput, field.annotation)
|
||||
credentials_meta = credentials_meta_type.model_validate(
|
||||
node.input_default[CREDENTIALS_FIELD_NAME]
|
||||
)
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = self.credentials_store.get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
if not credentials:
|
||||
raise ValueError(
|
||||
f"Unknown credentials #{credentials_meta.id} "
|
||||
f"for node #{node.id}"
|
||||
)
|
||||
if (
|
||||
credentials.provider != credentials_meta.provider
|
||||
or credentials.type != credentials_meta.type
|
||||
):
|
||||
logger.warning(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch: "
|
||||
f"{credentials_meta.type}<>{credentials.type};"
|
||||
f"{credentials_meta.provider}<>{credentials.provider}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Invalid credentials #{credentials.id} for node #{node.id}: "
|
||||
"type/provider mismatch"
|
||||
)
|
||||
node_credentials[node.id] = credentials
|
||||
|
||||
return node_credentials
|
||||
|
||||
|
||||
def llprint(message: str):
|
||||
"""
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
import logging
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs.supabase_integration_credentials_store import (
|
||||
SupabaseIntegrationCredentialsStore,
|
||||
)
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import (
|
||||
APIKeyCredentials,
|
||||
Credentials,
|
||||
CredentialsType,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Body,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Path,
|
||||
Query,
|
||||
Request,
|
||||
Response,
|
||||
)
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from supabase import Client
|
||||
|
||||
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
|
||||
@@ -28,6 +39,7 @@ def get_store(supabase: Client = Depends(get_supabase)):
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
login_url: str
|
||||
state_token: str
|
||||
|
||||
|
||||
@router.get("/{provider}/login")
|
||||
@@ -43,17 +55,17 @@ async def login(
|
||||
handler = _get_provider_oauth_handler(request, provider)
|
||||
|
||||
# Generate and store a secure random state token
|
||||
state = await store.store_state_token(user_id, provider)
|
||||
state_token = await store.store_state_token(user_id, provider)
|
||||
|
||||
requested_scopes = scopes.split(",") if scopes else []
|
||||
login_url = handler.get_login_url(requested_scopes, state)
|
||||
login_url = handler.get_login_url(requested_scopes, state_token)
|
||||
|
||||
return LoginResponse(login_url=login_url)
|
||||
return LoginResponse(login_url=login_url, state_token=state_token)
|
||||
|
||||
|
||||
class CredentialsMetaResponse(BaseModel):
|
||||
id: str
|
||||
type: Literal["oauth2", "api_key"]
|
||||
type: CredentialsType
|
||||
title: str | None
|
||||
scopes: list[str] | None
|
||||
username: str | None
|
||||
@@ -127,6 +139,52 @@ async def get_credential(
|
||||
return credential
|
||||
|
||||
|
||||
@router.post("/{provider}/credentials", status_code=201)
|
||||
async def create_api_key_credentials(
|
||||
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
provider: Annotated[str, Path(title="The provider to create credentials for")],
|
||||
api_key: Annotated[str, Body(title="The API key to store")],
|
||||
title: Annotated[str, Body(title="Optional title for the credentials")],
|
||||
expires_at: Annotated[
|
||||
int | None, Body(title="Unix timestamp when the key expires")
|
||||
] = None,
|
||||
) -> APIKeyCredentials:
|
||||
new_credentials = APIKeyCredentials(
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=title,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
try:
|
||||
store.add_creds(user_id, new_credentials)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to store credentials: {str(e)}"
|
||||
)
|
||||
return new_credentials
|
||||
|
||||
|
||||
@router.delete("/{provider}/credentials/{cred_id}", status_code=204)
|
||||
async def delete_credential(
|
||||
provider: Annotated[str, Path(title="The provider to delete credentials for")],
|
||||
cred_id: Annotated[str, Path(title="The ID of the credentials to delete")],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
|
||||
):
|
||||
creds = store.get_creds_by_id(user_id, cred_id)
|
||||
if not creds:
|
||||
raise HTTPException(status_code=404, detail="Credentials not found")
|
||||
if creds.provider != provider:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Credentials do not match the specified provider"
|
||||
)
|
||||
|
||||
store.delete_creds_by_id(user_id, cred_id)
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
# -------- UTILITIES --------- #
|
||||
|
||||
|
||||
@@ -145,8 +203,9 @@ def _get_provider_oauth_handler(req: Request, provider_name: str) -> BaseOAuthHa
|
||||
)
|
||||
|
||||
handler_class = HANDLERS_BY_NAME[provider_name]
|
||||
frontend_base_url = settings.config.frontend_base_url or str(req.base_url)
|
||||
return handler_class(
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
redirect_uri=str(req.url_for("callback", provider=provider_name)),
|
||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||
)
|
||||
|
||||
@@ -20,4 +20,6 @@ def get_user_id(payload: dict = Depends(auth_middleware)) -> str:
|
||||
|
||||
|
||||
def get_supabase() -> Client:
|
||||
return create_client(settings.secrets.supabase_url, settings.secrets.supabase_key)
|
||||
return create_client(
|
||||
settings.secrets.supabase_url, settings.secrets.supabase_service_role_key
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ from backend.data import db
|
||||
from backend.data.queue import AsyncEventQueue, AsyncRedisEventQueue
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.retry import conn_retry
|
||||
from backend.util.settings import Config
|
||||
from backend.util.settings import Config, Secrets
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
@@ -48,6 +48,7 @@ class AppService(AppProcess):
|
||||
event_queue: AsyncEventQueue = AsyncRedisEventQueue()
|
||||
use_db: bool = False
|
||||
use_redis: bool = False
|
||||
use_supabase: bool = False
|
||||
|
||||
def __init__(self, port):
|
||||
self.port = port
|
||||
@@ -76,6 +77,13 @@ class AppService(AppProcess):
|
||||
self.shared_event_loop.run_until_complete(db.connect())
|
||||
if self.use_redis:
|
||||
self.shared_event_loop.run_until_complete(self.event_queue.connect())
|
||||
if self.use_supabase:
|
||||
from supabase import create_client
|
||||
|
||||
secrets = Secrets()
|
||||
self.supabase = create_client(
|
||||
secrets.supabase_url, secrets.supabase_service_role_key
|
||||
)
|
||||
|
||||
# Initialize the async loop.
|
||||
async_thread = threading.Thread(target=self.__start_async_loop)
|
||||
|
||||
@@ -115,6 +115,12 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The port for agent server API to run on",
|
||||
)
|
||||
|
||||
frontend_base_url: str = Field(
|
||||
default="",
|
||||
description="Can be used to explicitly set the base URL for the frontend. "
|
||||
"This value is then used to generate redirect URLs for OAuth flows.",
|
||||
)
|
||||
|
||||
backend_cors_allow_origins: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("backend_cors_allow_origins")
|
||||
@@ -166,7 +172,9 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
"""Secrets for the server."""
|
||||
|
||||
supabase_url: str = Field(default="", description="Supabase URL")
|
||||
supabase_key: str = Field(default="", description="Supabase key")
|
||||
supabase_service_role_key: str = Field(
|
||||
default="", description="Supabase service role key"
|
||||
)
|
||||
|
||||
# OAuth server credentials for integrations
|
||||
github_client_id: str = Field(default="", description="GitHub OAuth client ID")
|
||||
|
||||
@@ -4,6 +4,7 @@ import time
|
||||
from backend.data import db
|
||||
from backend.data.block import Block, initialize_blocks
|
||||
from backend.data.execution import ExecutionResult, ExecutionStatus
|
||||
from backend.data.model import CREDENTIALS_FIELD_NAME
|
||||
from backend.data.queue import AsyncEventQueue
|
||||
from backend.data.user import create_default_user
|
||||
from backend.executor import ExecutionManager, ExecutionScheduler
|
||||
@@ -130,10 +131,19 @@ def execute_block_test(block: Block):
|
||||
else:
|
||||
log(f"{prefix} mock {mock_name} not found in block")
|
||||
|
||||
extra_exec_kwargs = {}
|
||||
|
||||
if CREDENTIALS_FIELD_NAME in block.input_schema.model_fields:
|
||||
if not block.test_credentials:
|
||||
raise ValueError(
|
||||
f"{prefix} requires credentials but has no test_credentials"
|
||||
)
|
||||
extra_exec_kwargs[CREDENTIALS_FIELD_NAME] = block.test_credentials
|
||||
|
||||
for input_data in block.test_input:
|
||||
log(f"{prefix} in: {input_data}")
|
||||
|
||||
for output_name, output_data in block.execute(input_data):
|
||||
for output_name, output_data in block.execute(input_data, **extra_exec_kwargs):
|
||||
if output_index >= len(block.test_output):
|
||||
raise ValueError(f"{prefix} produced output more than expected")
|
||||
ex_output_name, ex_output_data = block.test_output[output_index]
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from backend.data.block import get_blocks
|
||||
import pytest
|
||||
|
||||
from backend.data.block import Block, get_blocks
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
|
||||
def test_available_blocks():
|
||||
for block in get_blocks().values():
|
||||
execute_block_test(type(block)())
|
||||
@pytest.mark.parametrize("block", get_blocks().values(), ids=lambda b: b.name)
|
||||
def test_available_blocks(block: Block):
|
||||
execute_block_test(type(block)())
|
||||
|
||||
@@ -58,7 +58,7 @@ services:
|
||||
environment:
|
||||
- SUPABASE_URL=http://kong:8000
|
||||
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
- SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
- SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
@@ -66,6 +66,7 @@ services:
|
||||
- ENABLE_AUTH=true
|
||||
- PYRO_HOST=0.0.0.0
|
||||
- EXECUTIONMANAGER_HOST=executor
|
||||
- FRONTEND_BASE_URL=http://localhost:3000
|
||||
- BACKEND_CORS_ALLOW_ORIGINS=["http://localhost:3000"]
|
||||
ports:
|
||||
- "8006:8006"
|
||||
@@ -92,9 +93,9 @@ services:
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
environment:
|
||||
- NEXT_PUBLIC_SUPABASE_URL=http://kong:8000
|
||||
- SUPABASE_URL=http://kong:8000
|
||||
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
- SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
- SUPABASE_SERVICE_ROLE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJzZXJ2aWNlX3JvbGUiLAogICAgImlzcyI6ICJzdXBhYmFzZS1kZW1vIiwKICAgICJpYXQiOiAxNjQxNzY5MjAwLAogICAgImV4cCI6IDE3OTk1MzU2MDAKfQ.DaYlNEoUrrEn2Ig7tqibS-PHK5vgusbcbo7X36XVt4Q
|
||||
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
@@ -126,9 +127,7 @@ services:
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
environment:
|
||||
- SUPABASE_URL=http://kong:8000
|
||||
- SUPABASE_JWT_SECRET=your-super-secret-jwt-token-with-at-least-32-characters-long
|
||||
- SUPABASE_ANON_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyAgCiAgICAicm9sZSI6ICJhbm9uIiwKICAgICJpc3MiOiAic3VwYWJhc2UtZGVtbyIsCiAgICAiaWF0IjogMTY0MTc2OTIwMCwKICAgICJleHAiOiAxNzk5NTM1NjAwCn0.dc_X5iR_VP_qT0zsiyj_I_OZ2T9FtRU2BBNWN8Bu4GE
|
||||
- DATABASE_URL=postgresql://postgres:your-super-secret-and-long-postgres-password@db:5432/postgres?connect_timeout=60&schema=platform
|
||||
- REDIS_HOST=redis
|
||||
- REDIS_PORT=6379
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
import { OAuthPopupResultMessage } from "@/components/integrations/credentials-input";
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
// This route is intended to be used as the callback for integration OAuth flows,
|
||||
// controlled by the CredentialsInput component. The CredentialsInput opens the login
|
||||
// page in a pop-up window, which then redirects to this route to close the loop.
|
||||
export async function GET(request: Request) {
|
||||
const { searchParams, origin } = new URL(request.url);
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
|
||||
// Send message from popup window to host window
|
||||
const message: OAuthPopupResultMessage =
|
||||
code && state
|
||||
? { message_type: "oauth_popup_result", success: true, code, state }
|
||||
: {
|
||||
message_type: "oauth_popup_result",
|
||||
success: false,
|
||||
message: `Incomplete query: ${searchParams.toString()}`,
|
||||
};
|
||||
|
||||
// Return a response with the message as JSON and a script to close the window
|
||||
return new NextResponse(
|
||||
`
|
||||
<html>
|
||||
<body>
|
||||
<script>
|
||||
window.postMessage(${JSON.stringify(message)});
|
||||
window.close();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`,
|
||||
{
|
||||
headers: { "Content-Type": "text/html" },
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -5,12 +5,15 @@ import { ThemeProvider as NextThemesProvider } from "next-themes";
|
||||
import { ThemeProviderProps } from "next-themes/dist/types";
|
||||
import { TooltipProvider } from "@/components/ui/tooltip";
|
||||
import SupabaseProvider from "@/components/SupabaseProvider";
|
||||
import CredentialsProvider from "@/components/integrations/credentials-provider";
|
||||
|
||||
export function Providers({ children, ...props }: ThemeProviderProps) {
|
||||
return (
|
||||
<NextThemesProvider {...props}>
|
||||
<SupabaseProvider>
|
||||
<TooltipProvider>{children}</TooltipProvider>
|
||||
<CredentialsProvider>
|
||||
<TooltipProvider>{children}</TooltipProvider>
|
||||
</CredentialsProvider>
|
||||
</SupabaseProvider>
|
||||
</NextThemesProvider>
|
||||
);
|
||||
|
||||
@@ -255,13 +255,19 @@ export function CustomNode({ data, id, width, height }: NodeProps<CustomNode>) {
|
||||
return (
|
||||
(isRequired || isAdvancedOpen || isConnected || !isAdvanced) && (
|
||||
<div key={propKey} onMouseOver={() => {}}>
|
||||
<NodeHandle
|
||||
keyName={propKey}
|
||||
isConnected={isConnected}
|
||||
isRequired={isRequired}
|
||||
schema={propSchema}
|
||||
side="left"
|
||||
/>
|
||||
{"credentials_provider" in propSchema ? (
|
||||
<span className="text-m green -mb-1 text-gray-900">
|
||||
Credentials
|
||||
</span>
|
||||
) : (
|
||||
<NodeHandle
|
||||
keyName={propKey}
|
||||
isConnected={isConnected}
|
||||
isRequired={isRequired}
|
||||
schema={propSchema}
|
||||
side="left"
|
||||
/>
|
||||
)}
|
||||
{!isConnected && (
|
||||
<NodeGenericInputField
|
||||
className="mb-2 mt-1"
|
||||
|
||||
@@ -0,0 +1,419 @@
|
||||
import { z } from "zod";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useForm } from "react-hook-form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import useCredentials from "@/hooks/useCredentials";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import AutoGPTServerAPI from "@/lib/autogpt-server-api";
|
||||
import { NotionLogoIcon } from "@radix-ui/react-icons";
|
||||
import { FaGithub, FaGoogle } from "react-icons/fa";
|
||||
import { FC, useMemo, useState } from "react";
|
||||
import {
|
||||
APIKeyCredentials,
|
||||
CredentialsMetaInput,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
IconKey,
|
||||
IconKeyPlus,
|
||||
IconUser,
|
||||
IconUserPlus,
|
||||
} from "@/components/ui/icons";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectSeparator,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
|
||||
const providerIcons: Record<string, React.FC<{ className?: string }>> = {
|
||||
github: FaGithub,
|
||||
google: FaGoogle,
|
||||
notion: NotionLogoIcon,
|
||||
};
|
||||
|
||||
export type OAuthPopupResultMessage = { message_type: "oauth_popup_result" } & (
|
||||
| {
|
||||
success: true;
|
||||
code: string;
|
||||
state: string;
|
||||
}
|
||||
| {
|
||||
success: false;
|
||||
message: string;
|
||||
}
|
||||
);
|
||||
|
||||
export const CredentialsInput: FC<{
|
||||
className?: string;
|
||||
selectedCredentials?: CredentialsMetaInput;
|
||||
onSelectCredentials: (newValue: CredentialsMetaInput) => void;
|
||||
}> = ({ className, selectedCredentials, onSelectCredentials }) => {
|
||||
const api = useMemo(() => new AutoGPTServerAPI(), []);
|
||||
const credentials = useCredentials();
|
||||
const [isAPICredentialsModalOpen, setAPICredentialsModalOpen] =
|
||||
useState(false);
|
||||
const [isOAuth2FlowInProgress, setOAuth2FlowInProgress] = useState(false);
|
||||
const [oAuthPopupController, setOAuthPopupController] =
|
||||
useState<AbortController | null>(null);
|
||||
|
||||
if (!credentials) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (credentials.isLoading) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
|
||||
const {
|
||||
schema,
|
||||
provider,
|
||||
providerName,
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
savedApiKeys,
|
||||
savedOAuthCredentials,
|
||||
oAuthCallback,
|
||||
} = credentials;
|
||||
|
||||
async function handleOAuthLogin() {
|
||||
const { login_url, state_token } = await api.oAuthLogin(
|
||||
provider,
|
||||
schema.credentials_scopes,
|
||||
);
|
||||
setOAuth2FlowInProgress(true);
|
||||
const popup = window.open(login_url, "_blank", "popup=true");
|
||||
|
||||
const controller = new AbortController();
|
||||
setOAuthPopupController(controller);
|
||||
controller.signal.onabort = () => {
|
||||
setOAuth2FlowInProgress(false);
|
||||
popup?.close();
|
||||
};
|
||||
popup?.addEventListener(
|
||||
"message",
|
||||
async (e: MessageEvent<OAuthPopupResultMessage>) => {
|
||||
if (
|
||||
typeof e.data != "object" ||
|
||||
!(
|
||||
"message_type" in e.data &&
|
||||
e.data.message_type == "oauth_popup_result"
|
||||
)
|
||||
)
|
||||
return;
|
||||
|
||||
if (!e.data.success) {
|
||||
console.error("OAuth flow failed:", e.data.message);
|
||||
return;
|
||||
}
|
||||
|
||||
if (e.data.state !== state_token) return;
|
||||
|
||||
const credentials = await oAuthCallback(e.data.code, e.data.state);
|
||||
onSelectCredentials({
|
||||
id: credentials.id,
|
||||
type: "oauth2",
|
||||
title: credentials.title,
|
||||
provider,
|
||||
});
|
||||
controller.abort("success");
|
||||
},
|
||||
{ signal: controller.signal },
|
||||
);
|
||||
|
||||
setTimeout(
|
||||
() => {
|
||||
controller.abort("timeout");
|
||||
},
|
||||
5 * 60 * 1000,
|
||||
);
|
||||
}
|
||||
|
||||
const ProviderIcon = providerIcons[provider];
|
||||
const modals = (
|
||||
<>
|
||||
{supportsApiKey && (
|
||||
<APIKeyCredentialsModal
|
||||
open={isAPICredentialsModalOpen}
|
||||
onClose={() => setAPICredentialsModalOpen(false)}
|
||||
onCredentialsCreate={(credsMeta) => {
|
||||
onSelectCredentials(credsMeta);
|
||||
setAPICredentialsModalOpen(false);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{supportsOAuth2 && (
|
||||
<OAuth2FlowWaitingModal
|
||||
open={isOAuth2FlowInProgress}
|
||||
onClose={() => oAuthPopupController?.abort("canceled")}
|
||||
providerName={providerName}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
// No saved credentials yet
|
||||
if (savedApiKeys.length === 0 && savedOAuthCredentials.length === 0) {
|
||||
return (
|
||||
<>
|
||||
<div className={cn("flex flex-row space-x-2", className)}>
|
||||
{supportsOAuth2 && (
|
||||
<Button onClick={handleOAuthLogin}>
|
||||
<ProviderIcon className="mr-2 h-4 w-4" />
|
||||
{"Sign in with " + providerName}
|
||||
</Button>
|
||||
)}
|
||||
{supportsApiKey && (
|
||||
<Button onClick={() => setAPICredentialsModalOpen(true)}>
|
||||
<ProviderIcon className="mr-2 h-4 w-4" />
|
||||
Enter API key
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
{modals}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function handleValueChange(newValue: string) {
|
||||
if (newValue === "sign-in") {
|
||||
// Trigger OAuth2 sign in flow
|
||||
handleOAuthLogin();
|
||||
} else if (newValue === "add-api-key") {
|
||||
// Open API key dialog
|
||||
setAPICredentialsModalOpen(true);
|
||||
} else {
|
||||
const selectedCreds = savedApiKeys
|
||||
.concat(savedOAuthCredentials)
|
||||
.find((c) => c.id == newValue)!;
|
||||
|
||||
onSelectCredentials({
|
||||
id: selectedCreds.id,
|
||||
type: selectedCreds.type,
|
||||
provider: schema.credentials_provider,
|
||||
// title: customTitle, // TODO: add input for title
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Saved credentials exist
|
||||
return (
|
||||
<>
|
||||
<Select value={selectedCredentials?.id} onValueChange={handleValueChange}>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder={schema.placeholder} />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="nodrag">
|
||||
{savedOAuthCredentials.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
{credentials.username}
|
||||
</SelectItem>
|
||||
))}
|
||||
{savedApiKeys.map((credentials, index) => (
|
||||
<SelectItem key={index} value={credentials.id}>
|
||||
<ProviderIcon className="mr-2 inline h-4 w-4" />
|
||||
<IconKey className="mr-1.5 inline" />
|
||||
{credentials.title}
|
||||
</SelectItem>
|
||||
))}
|
||||
<SelectSeparator />
|
||||
{supportsOAuth2 && (
|
||||
<SelectItem value="sign-in">
|
||||
<IconUserPlus className="mr-1.5 inline" />
|
||||
Sign in with {providerName}
|
||||
</SelectItem>
|
||||
)}
|
||||
{supportsApiKey && (
|
||||
<SelectItem value="add-api-key">
|
||||
<IconKeyPlus className="mr-1.5 inline" />
|
||||
Add new API key
|
||||
</SelectItem>
|
||||
)}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
{modals}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export const APIKeyCredentialsModal: FC<{
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
onCredentialsCreate: (creds: CredentialsMetaInput) => void;
|
||||
}> = ({ open, onClose, onCredentialsCreate }) => {
|
||||
const credentials = useCredentials();
|
||||
|
||||
const formSchema = z.object({
|
||||
apiKey: z.string().min(1, "API Key is required"),
|
||||
title: z.string().min(1, "Name is required"),
|
||||
expiresAt: z.string().optional(),
|
||||
});
|
||||
|
||||
const form = useForm<z.infer<typeof formSchema>>({
|
||||
resolver: zodResolver(formSchema),
|
||||
defaultValues: {
|
||||
apiKey: "",
|
||||
title: "",
|
||||
expiresAt: "",
|
||||
},
|
||||
});
|
||||
|
||||
if (!credentials || credentials.isLoading || !credentials.supportsApiKey) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { schema, provider, providerName, createAPIKeyCredentials } =
|
||||
credentials;
|
||||
|
||||
async function onSubmit(values: z.infer<typeof formSchema>) {
|
||||
const expiresAt = values.expiresAt
|
||||
? new Date(values.expiresAt).getTime() / 1000
|
||||
: undefined;
|
||||
const newCredentials = await createAPIKeyCredentials({
|
||||
api_key: values.apiKey,
|
||||
title: values.title,
|
||||
expires_at: expiresAt,
|
||||
});
|
||||
onCredentialsCreate({
|
||||
provider,
|
||||
id: newCredentials.id,
|
||||
type: "api_key",
|
||||
title: newCredentials.title,
|
||||
});
|
||||
}
|
||||
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add new API key for {providerName}</DialogTitle>
|
||||
{schema.description && (
|
||||
<DialogDescription>{schema.description}</DialogDescription>
|
||||
)}
|
||||
</DialogHeader>
|
||||
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="apiKey"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>API Key</FormLabel>
|
||||
{schema.credentials_scopes && (
|
||||
<FormDescription>
|
||||
Required scope(s) for this block:{" "}
|
||||
{schema.credentials_scopes?.map((s, i, a) => (
|
||||
<span key={i}>
|
||||
<code>{s}</code>
|
||||
{i < a.length - 1 && ", "}
|
||||
</span>
|
||||
))}
|
||||
</FormDescription>
|
||||
)}
|
||||
<FormControl>
|
||||
<Input
|
||||
type="password"
|
||||
placeholder="Enter API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="title"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Name</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Enter a name for this API key..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="expiresAt"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>Expiration Date (Optional)</FormLabel>
|
||||
<FormControl>
|
||||
<Input
|
||||
type="datetime-local"
|
||||
placeholder="Select expiration date..."
|
||||
{...field}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<Button type="submit" className="w-full">
|
||||
Save & use this API key
|
||||
</Button>
|
||||
</form>
|
||||
</Form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
||||
export const OAuth2FlowWaitingModal: FC<{
|
||||
open: boolean;
|
||||
onClose: () => void;
|
||||
providerName: string;
|
||||
}> = ({ open, onClose, providerName }) => {
|
||||
return (
|
||||
<Dialog
|
||||
open={open}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) onClose();
|
||||
}}
|
||||
>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>
|
||||
Waiting on {providerName} sign-in process...
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
Complete the sign-in process in the pop-up window.
|
||||
<br />
|
||||
Closing this dialog will cancel the sign-in process.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,164 @@
|
||||
import AutoGPTServerAPI, {
|
||||
APIKeyCredentials,
|
||||
CredentialsMetaResponse,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import {
|
||||
createContext,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
const CREDENTIALS_PROVIDER_NAMES = ["github", "google", "notion"] as const;
|
||||
|
||||
type CredentialsProviderName = (typeof CREDENTIALS_PROVIDER_NAMES)[number];
|
||||
|
||||
const providerDisplayNames: Record<CredentialsProviderName, string> = {
|
||||
github: "GitHub",
|
||||
google: "Google",
|
||||
notion: "Notion",
|
||||
};
|
||||
|
||||
type APIKeyCredentialsCreatable = Omit<
|
||||
APIKeyCredentials,
|
||||
"id" | "provider" | "type"
|
||||
>;
|
||||
|
||||
export type CredentialsProviderData = {
|
||||
provider: string;
|
||||
providerName: string;
|
||||
savedApiKeys: CredentialsMetaResponse[];
|
||||
savedOAuthCredentials: CredentialsMetaResponse[];
|
||||
oAuthCallback: (
|
||||
code: string,
|
||||
state_token: string,
|
||||
) => Promise<CredentialsMetaResponse>;
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) => Promise<CredentialsMetaResponse>;
|
||||
};
|
||||
|
||||
export type CredentialsProvidersContextType = {
|
||||
[key in CredentialsProviderName]?: CredentialsProviderData;
|
||||
};
|
||||
|
||||
export const CredentialsProvidersContext =
|
||||
createContext<CredentialsProvidersContextType | null>(null);
|
||||
|
||||
export default function CredentialsProvider({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
const [providers, setProviders] =
|
||||
useState<CredentialsProvidersContextType | null>(null);
|
||||
const api = useMemo(() => new AutoGPTServerAPI(), []);
|
||||
|
||||
const addCredentials = useCallback(
|
||||
(
|
||||
provider: CredentialsProviderName,
|
||||
credentials: CredentialsMetaResponse,
|
||||
) => {
|
||||
setProviders((prev) => {
|
||||
if (!prev || !prev[provider]) return prev;
|
||||
|
||||
const updatedProvider = { ...prev[provider] };
|
||||
|
||||
if (credentials.type === "api_key") {
|
||||
updatedProvider.savedApiKeys = [
|
||||
...updatedProvider.savedApiKeys,
|
||||
credentials,
|
||||
];
|
||||
} else if (credentials.type === "oauth2") {
|
||||
updatedProvider.savedOAuthCredentials = [
|
||||
...updatedProvider.savedOAuthCredentials,
|
||||
credentials,
|
||||
];
|
||||
}
|
||||
|
||||
return {
|
||||
...prev,
|
||||
[provider]: updatedProvider,
|
||||
};
|
||||
});
|
||||
},
|
||||
[setProviders],
|
||||
);
|
||||
|
||||
/** Wraps `AutoGPTServerAPI.oAuthCallback`, and adds the result to the internal credentials store. */
|
||||
const oAuthCallback = useCallback(
|
||||
async (
|
||||
provider: CredentialsProviderName,
|
||||
code: string,
|
||||
state_token: string,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
const credsMeta = await api.oAuthCallback(provider, code, state_token);
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
},
|
||||
[api, addCredentials],
|
||||
);
|
||||
|
||||
/** Wraps `AutoGPTServerAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */
|
||||
const createAPIKeyCredentials = useCallback(
|
||||
async (
|
||||
provider: CredentialsProviderName,
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
const credsMeta = await api.createAPIKeyCredentials({
|
||||
provider,
|
||||
...credentials,
|
||||
});
|
||||
addCredentials(provider, credsMeta);
|
||||
return credsMeta;
|
||||
},
|
||||
[api, addCredentials],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
api.isAuthenticated().then((isAuthenticated) => {
|
||||
if (!isAuthenticated) return;
|
||||
|
||||
CREDENTIALS_PROVIDER_NAMES.forEach((provider) => {
|
||||
api.listCredentials(provider).then((response) => {
|
||||
const { oauthCreds, apiKeys } = response.reduce<{
|
||||
oauthCreds: CredentialsMetaResponse[];
|
||||
apiKeys: CredentialsMetaResponse[];
|
||||
}>(
|
||||
(acc, cred) => {
|
||||
if (cred.type === "oauth2") {
|
||||
acc.oauthCreds.push(cred);
|
||||
} else if (cred.type === "api_key") {
|
||||
acc.apiKeys.push(cred);
|
||||
}
|
||||
return acc;
|
||||
},
|
||||
{ oauthCreds: [], apiKeys: [] },
|
||||
);
|
||||
|
||||
setProviders((prev) => ({
|
||||
...prev,
|
||||
[provider]: {
|
||||
provider,
|
||||
providerName: providerDisplayNames[provider],
|
||||
savedApiKeys: apiKeys,
|
||||
savedOAuthCredentials: oauthCreds,
|
||||
oAuthCallback: (code: string, state_token: string) =>
|
||||
oAuthCallback(provider, code, state_token),
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) => createAPIKeyCredentials(provider, credentials),
|
||||
},
|
||||
}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}, [api, createAPIKeyCredentials, oAuthCallback]);
|
||||
|
||||
return (
|
||||
<CredentialsProvidersContext.Provider value={providers}>
|
||||
{children}
|
||||
</CredentialsProvidersContext.Provider>
|
||||
);
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
BlockIOStringSubSchema,
|
||||
BlockIONumberSubSchema,
|
||||
BlockIOBooleanSubSchema,
|
||||
BlockIOCredentialsSubSchema,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import React, { FC, useCallback, useEffect, useState } from "react";
|
||||
import { Button } from "./ui/button";
|
||||
@@ -23,6 +24,7 @@ import {
|
||||
import { Input } from "./ui/input";
|
||||
import NodeHandle from "./NodeHandle";
|
||||
import { ConnectionData } from "./CustomNode";
|
||||
import { CredentialsInput } from "./integrations/credentials-input";
|
||||
|
||||
type NodeObjectInputTreeProps = {
|
||||
selfKey?: string;
|
||||
@@ -114,6 +116,18 @@ export const NodeGenericInputField: FC<{
|
||||
console.warn(`Unsupported 'allOf' in schema for '${propKey}'!`, propSchema);
|
||||
}
|
||||
|
||||
if ("credentials_provider" in propSchema) {
|
||||
return (
|
||||
<NodeCredentialsInput
|
||||
selfKey={propKey}
|
||||
value={currentValue}
|
||||
errors={errors}
|
||||
className={className}
|
||||
handleInputChange={handleInputChange}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if ("properties" in propSchema) {
|
||||
return (
|
||||
<NodeObjectInputTree
|
||||
@@ -277,6 +291,28 @@ export const NodeGenericInputField: FC<{
|
||||
}
|
||||
};
|
||||
|
||||
const NodeCredentialsInput: FC<{
|
||||
selfKey: string;
|
||||
value: any;
|
||||
errors: { [key: string]: string | undefined };
|
||||
handleInputChange: NodeObjectInputTreeProps["handleInputChange"];
|
||||
className?: string;
|
||||
}> = ({ selfKey, value, errors, handleInputChange, className }) => {
|
||||
return (
|
||||
<div className={cn("flex flex-col", className)}>
|
||||
<CredentialsInput
|
||||
onSelectCredentials={(credsMeta) =>
|
||||
handleInputChange(selfKey, credsMeta)
|
||||
}
|
||||
selectedCredentials={value}
|
||||
/>
|
||||
{errors[selfKey] && (
|
||||
<span className="error-message">{errors[selfKey]}</span>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const NodeKeyValueInput: FC<{
|
||||
selfKey: string;
|
||||
schema: BlockIOKVSubSchema;
|
||||
|
||||
@@ -575,4 +575,148 @@ export const IconMegaphone = createIcon((props) => (
|
||||
</svg>
|
||||
));
|
||||
|
||||
/**
|
||||
* Key icon component.
|
||||
*
|
||||
* @component IconKey
|
||||
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
|
||||
* @returns {JSX.Element} - The key icon.
|
||||
*
|
||||
* @example
|
||||
* // Default usage
|
||||
* <IconKey />
|
||||
*
|
||||
* @example
|
||||
* // With custom color and size
|
||||
* <IconKey className="text-primary" size="lg" />
|
||||
*
|
||||
* @example
|
||||
* // With custom size and onClick handler
|
||||
* <IconKey size="sm" onClick={handleOnClick} />
|
||||
*/
|
||||
export const IconKey = createIcon((props) => (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
{...props}
|
||||
>
|
||||
<path d="M2.586 17.414A2 2 0 0 0 2 18.828V21a1 1 0 0 0 1 1h3a1 1 0 0 0 1-1v-1a1 1 0 0 1 1-1h1a1 1 0 0 0 1-1v-1a1 1 0 0 1 1-1h.172a2 2 0 0 0 1.414-.586l.814-.814a6.5 6.5 0 1 0-4-4z" />
|
||||
<circle cx="16.5" cy="7.5" r=".5" fill="currentColor" />
|
||||
</svg>
|
||||
));
|
||||
|
||||
/**
|
||||
* Key(+) icon component.
|
||||
*
|
||||
* @component IconKeyPlus
|
||||
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
|
||||
* @returns {JSX.Element} - The key(+) icon.
|
||||
*
|
||||
* @example
|
||||
* // Default usage
|
||||
* <IconKeyPlus />
|
||||
*
|
||||
* @example
|
||||
* // With custom color and size
|
||||
* <IconKeyPlus className="text-primary" size="lg" />
|
||||
*
|
||||
* @example
|
||||
* // With custom size and onClick handler
|
||||
* <IconKeyPlus size="sm" onClick={handleOnClick} />
|
||||
*/
|
||||
export const IconKeyPlus = createIcon((props) => (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
{...props}
|
||||
>
|
||||
<path d="M2.586 17.414A2 2 0 0 0 2 18.828V21a1 1 0 0 0 1 1h3a1 1 0 0 0 1-1v-1a1 1 0 0 1 1-1h1a1 1 0 0 0 1-1v-1a1 1 0 0 1 1-1h.172a2 2 0 0 0 1.414-.586l.814-.814a6.5 6.5 0 1 0-4-4z" />
|
||||
{/* <circle cx="16.5" cy="7.5" r=".5" fill="currentColor" /> */}
|
||||
<line x1="15.6" x2="15.6" y1="5.4" y2="11.4" />
|
||||
<line x1="12.6" x2="18.6" y1="8.4" y2="8.4" />
|
||||
</svg>
|
||||
));
|
||||
|
||||
/**
|
||||
* User icon component.
|
||||
*
|
||||
* @component IconUser
|
||||
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
|
||||
* @returns {JSX.Element} - The user icon.
|
||||
*
|
||||
* @example
|
||||
* // Default usage
|
||||
* <IconUser />
|
||||
*
|
||||
* @example
|
||||
* // With custom color and size
|
||||
* <IconUser className="text-primary" size="lg" />
|
||||
*
|
||||
* @example
|
||||
* // With custom size and onClick handler
|
||||
* <IconUser size="sm" onClick={handleOnClick} />
|
||||
*/
|
||||
export const IconUser = createIcon((props) => (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
{...props}
|
||||
>
|
||||
<path d="M19 21v-2a4 4 0 0 0-4-4H9a4 4 0 0 0-4 4v2" />
|
||||
<circle cx="12" cy="7" r="4" />
|
||||
</svg>
|
||||
));
|
||||
|
||||
/**
|
||||
* User(+) icon component.
|
||||
*
|
||||
* @component IconUserPlus
|
||||
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
|
||||
* @returns {JSX.Element} - The user plus icon.
|
||||
*
|
||||
* @example
|
||||
* // Default usage
|
||||
* <IconUserPlus />
|
||||
*
|
||||
* @example
|
||||
* // With custom color and size
|
||||
* <IconUserPlus className="text-primary" size="lg" />
|
||||
*
|
||||
* @example
|
||||
* // With custom size and onClick handler
|
||||
* <IconUserPlus size="sm" onClick={handleOnClick} />
|
||||
*/
|
||||
export const IconUserPlus = createIcon((props) => (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
{...props}
|
||||
>
|
||||
<path d="M16 21v-2a4 4 0 0 0-4-4H6a4 4 0 0 0-4 4v2" />
|
||||
<circle cx="9" cy="7" r="4" />
|
||||
<line x1="19" x2="19" y1="8" y2="14" />
|
||||
<line x1="22" x2="16" y1="11" y2="11" />
|
||||
</svg>
|
||||
));
|
||||
|
||||
export { iconVariants };
|
||||
|
||||
77
autogpt_platform/frontend/src/hooks/useCredentials.ts
Normal file
77
autogpt_platform/frontend/src/hooks/useCredentials.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import { useContext } from "react";
|
||||
import { CustomNodeData } from "@/components/CustomNode";
|
||||
import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api";
|
||||
import { Node, useNodeId, useNodesData } from "@xyflow/react";
|
||||
import {
|
||||
CredentialsProviderData,
|
||||
CredentialsProvidersContext,
|
||||
} from "@/components/integrations/credentials-provider";
|
||||
|
||||
export type CredentialsData =
|
||||
| {
|
||||
provider: string;
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
supportsApiKey: boolean;
|
||||
supportsOAuth2: boolean;
|
||||
isLoading: true;
|
||||
}
|
||||
| (CredentialsProviderData & {
|
||||
schema: BlockIOCredentialsSubSchema;
|
||||
supportsApiKey: boolean;
|
||||
supportsOAuth2: boolean;
|
||||
isLoading: false;
|
||||
});
|
||||
|
||||
export default function useCredentials(): CredentialsData | null {
|
||||
const nodeId = useNodeId();
|
||||
const allProviders = useContext(CredentialsProvidersContext);
|
||||
|
||||
if (!nodeId) {
|
||||
throw new Error("useCredentials must be within a CustomNode");
|
||||
}
|
||||
|
||||
const data = useNodesData<Node<CustomNodeData>>(nodeId)!.data;
|
||||
const credentialsSchema = data.inputSchema.properties
|
||||
.credentials as BlockIOCredentialsSubSchema;
|
||||
|
||||
// If block input schema doesn't have credentials, return null
|
||||
if (!credentialsSchema) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const provider = allProviders
|
||||
? allProviders[credentialsSchema?.credentials_provider]
|
||||
: null;
|
||||
|
||||
const supportsApiKey =
|
||||
credentialsSchema.credentials_types.includes("api_key");
|
||||
const supportsOAuth2 = credentialsSchema.credentials_types.includes("oauth2");
|
||||
|
||||
// No provider means maybe it's still loading
|
||||
if (!provider) {
|
||||
return {
|
||||
provider: credentialsSchema.credentials_provider,
|
||||
schema: credentialsSchema,
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
isLoading: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Filter by OAuth credentials that have sufficient scopes for this block
|
||||
const requiredScopes = credentialsSchema.credentials_scopes;
|
||||
const savedOAuthCredentials = requiredScopes
|
||||
? provider.savedOAuthCredentials.filter((c) =>
|
||||
new Set(c.scopes).isSupersetOf(new Set(requiredScopes)),
|
||||
)
|
||||
: provider.savedOAuthCredentials;
|
||||
|
||||
return {
|
||||
...provider,
|
||||
schema: credentialsSchema,
|
||||
supportsApiKey,
|
||||
supportsOAuth2,
|
||||
savedOAuthCredentials,
|
||||
isLoading: false,
|
||||
};
|
||||
}
|
||||
@@ -1,6 +1,10 @@
|
||||
import { SupabaseClient } from "@supabase/supabase-js";
|
||||
import {
|
||||
AnalyticsMetrics,
|
||||
AnalyticsDetails,
|
||||
APIKeyCredentials,
|
||||
Block,
|
||||
CredentialsMetaResponse,
|
||||
Graph,
|
||||
GraphCreatable,
|
||||
GraphUpdateable,
|
||||
@@ -9,9 +13,8 @@ import {
|
||||
GraphExecuteResponse,
|
||||
ExecutionMeta,
|
||||
NodeExecutionResult,
|
||||
OAuth2Credentials,
|
||||
User,
|
||||
AnalyticsMetrics,
|
||||
AnalyticsDetails,
|
||||
} from "./types";
|
||||
|
||||
export default class BaseAutoGPTServerAPI {
|
||||
@@ -34,6 +37,14 @@ export default class BaseAutoGPTServerAPI {
|
||||
this.supabaseClient = supabaseClient;
|
||||
}
|
||||
|
||||
async isAuthenticated(): Promise<boolean> {
|
||||
if (!this.supabaseClient) return false;
|
||||
const {
|
||||
data: { session },
|
||||
} = await this.supabaseClient?.auth.getSession();
|
||||
return session != null;
|
||||
}
|
||||
|
||||
async createUser(): Promise<User> {
|
||||
return this._request("POST", "/auth/user", {});
|
||||
}
|
||||
@@ -156,6 +167,53 @@ export default class BaseAutoGPTServerAPI {
|
||||
).map(parseNodeExecutionResultTimestamps);
|
||||
}
|
||||
|
||||
async oAuthLogin(
|
||||
provider: string,
|
||||
scopes?: string[],
|
||||
): Promise<{ login_url: string; state_token: string }> {
|
||||
const query = scopes ? { scopes: scopes.join(",") } : undefined;
|
||||
return await this._get(`/integrations/${provider}/login`, query);
|
||||
}
|
||||
|
||||
async oAuthCallback(
|
||||
provider: string,
|
||||
code: string,
|
||||
state_token: string,
|
||||
): Promise<CredentialsMetaResponse> {
|
||||
return this._request("POST", `/integrations/${provider}/callback`, {
|
||||
code,
|
||||
state_token,
|
||||
});
|
||||
}
|
||||
|
||||
async createAPIKeyCredentials(
|
||||
credentials: Omit<APIKeyCredentials, "id" | "type">,
|
||||
): Promise<APIKeyCredentials> {
|
||||
return this._request(
|
||||
"POST",
|
||||
`/integrations/${credentials.provider}/credentials`,
|
||||
credentials,
|
||||
);
|
||||
}
|
||||
|
||||
async listCredentials(provider: string): Promise<CredentialsMetaResponse[]> {
|
||||
return this._get(`/integrations/${provider}/credentials`);
|
||||
}
|
||||
|
||||
async getCredentials(
|
||||
provider: string,
|
||||
id: string,
|
||||
): Promise<APIKeyCredentials | OAuth2Credentials> {
|
||||
return this._get(`/integrations/${provider}/credentials/${id}`);
|
||||
}
|
||||
|
||||
async deleteCredentials(provider: string, id: string): Promise<void> {
|
||||
return this._request(
|
||||
"DELETE",
|
||||
`/integrations/${provider}/credentials/${id}`,
|
||||
);
|
||||
}
|
||||
|
||||
async logMetric(metric: AnalyticsMetrics) {
|
||||
return this._request("POST", "/analytics/log_raw_metric", metric);
|
||||
}
|
||||
@@ -164,14 +222,14 @@ export default class BaseAutoGPTServerAPI {
|
||||
return this._request("POST", "/analytics/log_raw_analytics", analytic);
|
||||
}
|
||||
|
||||
private async _get(path: string) {
|
||||
return this._request("GET", path);
|
||||
private async _get(path: string, query?: Record<string, any>) {
|
||||
return this._request("GET", path, query);
|
||||
}
|
||||
|
||||
private async _request(
|
||||
method: "GET" | "POST" | "PUT" | "PATCH",
|
||||
method: "GET" | "POST" | "PUT" | "PATCH" | "DELETE",
|
||||
path: string,
|
||||
payload?: { [key: string]: any },
|
||||
payload?: Record<string, any>,
|
||||
) {
|
||||
if (method != "GET") {
|
||||
console.debug(`${method} ${path} payload:`, payload);
|
||||
@@ -181,18 +239,25 @@ export default class BaseAutoGPTServerAPI {
|
||||
(await this.supabaseClient?.auth.getSession())?.data.session
|
||||
?.access_token || "";
|
||||
|
||||
const response = await fetch(this.baseUrl + path, {
|
||||
let url = this.baseUrl + path;
|
||||
if (method === "GET" && payload) {
|
||||
// For GET requests, use payload as query
|
||||
const queryParams = new URLSearchParams(payload);
|
||||
url += `?${queryParams.toString()}`;
|
||||
}
|
||||
|
||||
const hasRequestBody = method !== "GET" && payload !== undefined;
|
||||
const response = await fetch(url, {
|
||||
method,
|
||||
headers:
|
||||
method != "GET"
|
||||
? {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: token ? `Bearer ${token}` : "",
|
||||
}
|
||||
: {
|
||||
Authorization: token ? `Bearer ${token}` : "",
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
headers: hasRequestBody
|
||||
? {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: token ? `Bearer ${token}` : "",
|
||||
}
|
||||
: {
|
||||
Authorization: token ? `Bearer ${token}` : "",
|
||||
},
|
||||
body: hasRequestBody ? JSON.stringify(payload) : undefined,
|
||||
});
|
||||
const response_data = await response.json();
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ export type BlockIOSubSchema =
|
||||
|
||||
type BlockIOSimpleTypeSubSchema =
|
||||
| BlockIOObjectSubSchema
|
||||
| BlockIOCredentialsSubSchema
|
||||
| BlockIOKVSubSchema
|
||||
| BlockIOArraySubSchema
|
||||
| BlockIOStringSubSchema
|
||||
@@ -91,6 +92,14 @@ export type BlockIOBooleanSubSchema = BlockIOSubSchemaMeta & {
|
||||
default?: boolean;
|
||||
};
|
||||
|
||||
export type CredentialsType = "api_key" | "oauth2";
|
||||
|
||||
export type BlockIOCredentialsSubSchema = BlockIOSubSchemaMeta & {
|
||||
credentials_provider: "github" | "google" | "notion";
|
||||
credentials_scopes?: string[];
|
||||
credentials_types: Array<CredentialsType>;
|
||||
};
|
||||
|
||||
export type BlockIONullSubSchema = BlockIOSubSchemaMeta & {
|
||||
type: "null";
|
||||
};
|
||||
@@ -205,6 +214,51 @@ export type NodeExecutionResult = {
|
||||
end_time?: Date;
|
||||
};
|
||||
|
||||
/* Mirror of backend/server/integrations.py:CredentialsMetaResponse */
|
||||
export type CredentialsMetaResponse = {
|
||||
id: string;
|
||||
type: CredentialsType;
|
||||
title?: string;
|
||||
scopes?: Array<string>;
|
||||
username?: string;
|
||||
};
|
||||
|
||||
/* Mirror of backend/data/model.py:CredentialsMetaInput */
|
||||
export type CredentialsMetaInput = {
|
||||
id: string;
|
||||
type: CredentialsType;
|
||||
title?: string;
|
||||
provider: string;
|
||||
};
|
||||
|
||||
/* Mirror of autogpt_libs/supabase_integration_credentials_store/types.py:_BaseCredentials */
|
||||
type BaseCredentials = {
|
||||
id: string;
|
||||
type: CredentialsType;
|
||||
title?: string;
|
||||
provider: string;
|
||||
};
|
||||
|
||||
/* Mirror of autogpt_libs/supabase_integration_credentials_store/types.py:OAuth2Credentials */
|
||||
export type OAuth2Credentials = BaseCredentials & {
|
||||
type: "oauth2";
|
||||
scopes: string[];
|
||||
username?: string;
|
||||
access_token: string;
|
||||
access_token_expires_at?: number;
|
||||
refresh_token?: string;
|
||||
refresh_token_expires_at?: number;
|
||||
metadata: Record<string, any>;
|
||||
};
|
||||
|
||||
/* Mirror of autogpt_libs/supabase_integration_credentials_store/types.py:APIKeyCredentials */
|
||||
export type APIKeyCredentials = BaseCredentials & {
|
||||
type: "api_key";
|
||||
title: string;
|
||||
api_key: string;
|
||||
expires_at?: number;
|
||||
};
|
||||
|
||||
export type User = {
|
||||
id: string;
|
||||
email: string;
|
||||
|
||||
@@ -86,3 +86,9 @@ env:
|
||||
REDIS_HOST: "redis-dev-master.redis-dev.svc.cluster.local"
|
||||
REDIS_PORT: "6379"
|
||||
BACKEND_CORS_ALLOW_ORIGINS: ["https://dev-builder.agpt.co"]
|
||||
SUPABASE_SERVICE_ROLE_KEY: ""
|
||||
GITHUB_CLIENT_ID: ""
|
||||
GITHUB_CLIENT_SECRET: ""
|
||||
FRONTEND_BASE_URL: ""
|
||||
SUPABASE_URL: ""
|
||||
SUPABASE_JWT_SECRET: ""
|
||||
|
||||
@@ -84,7 +84,7 @@ Follow these steps to create and test a new block:
|
||||
5. **Implement the `run` method with error handling:**, this should contain the main logic of the block:
|
||||
|
||||
```python
|
||||
def run(self, input_data: Input) -> BlockOutput:
|
||||
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
try:
|
||||
topic = input_data.topic
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||
@@ -105,6 +105,145 @@ Follow these steps to create and test a new block:
|
||||
- **Error handling**: Handle various exceptions that might occur during the API request and data processing.
|
||||
- **Yield**: Use `yield` to output the results.
|
||||
|
||||
### Blocks with authentication
|
||||
|
||||
Our system supports auth offloading for API keys and OAuth2 authorization flows.
|
||||
Adding a block with API key authentication is straight-forward, as is adding a block
|
||||
for a service that we already have OAuth2 support for.
|
||||
|
||||
Implementing the block itself is relatively simple. On top of the instructions above,
|
||||
you're going to add a `credentials` parameter to the `Input` model and the `run` method:
|
||||
```python
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import (
|
||||
APIKeyCredentials,
|
||||
OAuth2Credentials,
|
||||
Credentials,
|
||||
)
|
||||
|
||||
from backend.data.block import Block, BlockOutput, BlockSchema
|
||||
from backend.data.model import CredentialsField
|
||||
|
||||
|
||||
# API Key auth:
|
||||
class BlockWithAPIKeyAuth(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials = CredentialsField(
|
||||
provider="github",
|
||||
supported_credential_types={"api_key"},
|
||||
required_scopes={"repo"},
|
||||
description="The GitHub integration can be used with "
|
||||
"any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
|
||||
# ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: APIKeyCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
...
|
||||
|
||||
# OAuth:
|
||||
class BlockWithOAuth(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials = CredentialsField(
|
||||
provider="github",
|
||||
supported_credential_types={"oauth2"},
|
||||
required_scopes={"repo"},
|
||||
description="The GitHub integration can be used with OAuth.",
|
||||
)
|
||||
|
||||
# ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: OAuth2Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
...
|
||||
|
||||
# API Key auth + OAuth:
|
||||
class BlockWithAPIKeyAndOAuth(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials = CredentialsField(
|
||||
provider="github",
|
||||
supported_credential_types={"api_key", "oauth2"},
|
||||
required_scopes={"repo"},
|
||||
description="The GitHub integration can be used with OAuth, "
|
||||
"or any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
|
||||
# ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: Credentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
...
|
||||
```
|
||||
The credentials will be automagically injected by the executor in the back end.
|
||||
|
||||
The `APIKeyCredentials` and `OAuth2Credentials` models are defined [here](https://github.com/Significant-Gravitas/AutoGPT/blob/master/rnd/autogpt_libs/autogpt_libs/supabase_integration_credentials_store/types.py).
|
||||
To use them in e.g. an API request, you can either access the token directly:
|
||||
```python
|
||||
# credentials: APIKeyCredentials
|
||||
response = requests.post(
|
||||
url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {credentials.api_key.get_secret_value()})",
|
||||
},
|
||||
)
|
||||
|
||||
# credentials: OAuth2Credentials
|
||||
response = requests.post(
|
||||
url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {credentials.access_token.get_secret_value()})",
|
||||
},
|
||||
)
|
||||
```
|
||||
or use the shortcut `credentials.bearer()`:
|
||||
```python
|
||||
# credentials: APIKeyCredentials | OAuth2Credentials
|
||||
response = requests.post(
|
||||
url,
|
||||
headers={"Authorization": credentials.bearer()},
|
||||
)
|
||||
```
|
||||
|
||||
#### Adding an OAuth2 service integration
|
||||
|
||||
To add support for a new OAuth2-authenticated service, you'll need to add an `OAuthHandler`.
|
||||
All our existing handlers and the base class can be found [here][OAuth2 handlers].
|
||||
|
||||
Every handler must implement the following parts of the [`BaseOAuthHandler`] interface:
|
||||
- `PROVIDER_NAME`
|
||||
- `__init__(client_id, client_secret, redirect_uri)`
|
||||
- `get_login_url(scopes, state)`
|
||||
- `exchange_code_for_tokens(code)`
|
||||
- `_refresh_tokens(credentials)`
|
||||
|
||||
As you can see, this is modeled after the standard OAuth2 flow.
|
||||
|
||||
Aside from implementing the `OAuthHandler` itself, adding a handler into the system requires two more things:
|
||||
- Adding the handler class to `HANDLERS_BY_NAME` [here](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/backend/backend/integrations/oauth/__init__.py)
|
||||
- Adding `{provider}_client_id` and `{provider}_client_secret` to the application's `Secrets` [here](https://github.com/Significant-Gravitas/AutoGPT/blob/e3f35d79c7e9fc6ee0cabefcb73e0fad15a0ce2d/autogpt_platform/backend/backend/util/settings.py#L132)
|
||||
|
||||
[OAuth2 handlers]: https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt_platform/backend/backend/integrations/oauth
|
||||
[`BaseOAuthHandler`]: https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/backend/backend/integrations/oauth/base.py
|
||||
|
||||
#### Example: GitHub integration
|
||||
- GitHub blocks with API key + OAuth2 support: [`blocks/github`](https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpt_platform/backend/backend/blocks/github/)
|
||||
- GitHub OAuth2 handler: [`integrations/oauth/github.py`](https://github.com/Significant-Gravitas/AutoGPT/blob/master/autogpt_platform/backend/backend/integrations/oauth/github.py)
|
||||
|
||||
## Key Points to Remember
|
||||
|
||||
- **Unique ID**: Give your block a unique ID in the **init** method.
|
||||
@@ -117,7 +256,8 @@ Follow these steps to create and test a new block:
|
||||
|
||||
The testing of blocks is handled by `test_block.py`, which does the following:
|
||||
|
||||
1. It calls the block with the provided `test_input`.
|
||||
1. It calls the block with the provided `test_input`.
|
||||
If the block has a `credentials` field, `test_credentials` is passed in as well.
|
||||
2. If a `test_mock` is provided, it temporarily replaces the specified methods with the mock functions.
|
||||
3. It then asserts that the output matches the `test_output`.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user