mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(blocks): add Replicate model blocks (#10337)
This PR adds two new blocks for running Replicate models in AutoGPT:
ReplicateModelBlock: Synchronous execution of any Replicate model with
custom inputs
Key Features:
Support for any public Replicate model via model name (e.g.,
"stability-ai/stable-diffusion")
Custom input parameters via dictionary (e.g., {"prompt": "a beautiful
landscape"})
Optional model version specification for reproducible results
Proper credentials handling using ProviderName.REPLICATE
Comprehensive test suite with 12 test cases covering success, error, and
edge cases
Type-safe implementation with full pyright compliance
Mock methods for testing and development
# Checklist 📋
## For code changes:
- [X] I have clearly listed my changes in the PR description
- [X] I have made a test plan
- [X] I have tested my changes according to the test plan:
- Unit tests for ReplicateModelBlock (6 test cases)
- Test block initialization and configuration
- Test mock methods for development
- Test error handling and edge cases
- Verify type safety with pyright (0 errors)
- Verify code formatting with Black and isort
- Verify linting with Ruff (0 errors)
- Test credentials handling with ProviderName.REPLICATE
- Test model version specification functionality
- Test both synchronous and asynchronous execution paths
---------
Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
This commit is contained in:
24
autogpt_platform/backend/backend/blocks/replicate/_auth.py
Normal file
24
autogpt_platform/backend/backend/blocks/replicate/_auth.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsMetaInput, ProviderName
|
||||
|
||||
# Mock API-key credentials for the Replicate provider; the key is a
# placeholder string, not a real secret.
TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="replicate",
    api_key=SecretStr("mock-replicate-api-key"),
    title="Mock Replicate API key",
    expires_at=None,
)

# Metadata-only reference to TEST_CREDENTIALS (no secret value), in the
# shape passed as the "credentials" entry of a block's test input.
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    # The display title comes from the credential's title, not its type.
    "title": TEST_CREDENTIALS.title,
}

# Replicate authenticates with a plain API key.
ReplicateCredentials = APIKeyCredentials
# Credentials input field type restricted to the Replicate provider and
# the "api_key" credential type.
ReplicateCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.REPLICATE], Literal["api_key"]
]
|
||||
39
autogpt_platform/backend/backend/blocks/replicate/_helper.py
Normal file
39
autogpt_platform/backend/backend/blocks/replicate/_helper.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ReplicateOutputs = FileOutput | list[FileOutput] | list[str] | str | list[dict]
|
||||
|
||||
|
||||
def extract_result(output: ReplicateOutputs) -> str:
    """Reduce a Replicate model output to a single string.

    Handling by output shape:
    - non-empty list whose first item is a ``FileOutput``: URL of that
      first item (later items are ignored)
    - non-empty list whose first item is a ``str``: all items joined
      without a separator (text output)
    - non-empty list whose first item is a ``dict``: ``str()`` of that
      first dict
    - non-empty list of anything else: a generic "unable to process"
      message, and an error is logged
    - single ``FileOutput``: its URL
    - single ``str``: returned unchanged
    - anything else (including an empty list): ``"No output received"``,
      and an error is logged

    Args:
        output: The value returned by the Replicate client.

    Returns:
        The extracted URL/text, or one of the fallback messages above.
    """
    if isinstance(output, list) and output:
        # Only the first element is inspected to decide how to treat the
        # whole list; checking every element was deliberately avoided.
        head = output[0]
        if isinstance(head, FileOutput):
            return head.url
        if isinstance(head, str):
            # Text output arrives as string chunks; stitch them together.
            return "".join(output)  # type: ignore  # known not to be FileOutputs here
        if isinstance(head, dict):
            return str(head)
        logger.error(
            "Replicate generated a new output type that's not a file output or a str in a replicate block"
        )
        return (
            "Unable to process result. Please contact us with the models and inputs used"
        )

    if isinstance(output, FileOutput):
        return output.url

    if isinstance(output, str):
        # The client's type hints are loose, so a bare string can show up.
        return output

    # Nothing usable came back (e.g. None or an empty list).
    logger.error(
        "We somehow didn't get an output from a replicate block. This is almost certainly an error"
    )
    return "No output received"
|
||||
@@ -1,33 +1,17 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
from replicate.helpers import FileOutput
|
||||
|
||||
from backend.blocks.replicate._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ReplicateCredentialsInput,
|
||||
)
|
||||
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="replicate",
|
||||
api_key=SecretStr("mock-replicate-api-key"),
|
||||
title="Mock Replicate API key",
|
||||
expires_at=None,
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.type,
|
||||
}
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
|
||||
|
||||
# Model name enum
|
||||
@@ -55,9 +39,7 @@ class ImageType(str, Enum):
|
||||
|
||||
class ReplicateFluxAdvancedModelBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
credentials: CredentialsMetaInput[
|
||||
Literal[ProviderName.REPLICATE], Literal["api_key"]
|
||||
] = CredentialsField(
|
||||
credentials: ReplicateCredentialsInput = CredentialsField(
|
||||
description="The Replicate integration can be used with "
|
||||
"any API key with sufficient permissions for the blocks it is used on.",
|
||||
)
|
||||
@@ -201,7 +183,7 @@ class ReplicateFluxAdvancedModelBlock(Block):
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
|
||||
# Run the model with additional parameters
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
|
||||
output: ReplicateOutputs = await client.async_run( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
|
||||
f"{model_name}",
|
||||
input={
|
||||
"prompt": prompt,
|
||||
@@ -217,21 +199,6 @@ class ReplicateFluxAdvancedModelBlock(Block):
|
||||
wait=False, # don't arbitrarily return data:octect/stream or sometimes url depending on the model???? what is this api
|
||||
)
|
||||
|
||||
# Check if output is a list or a string and extract accordingly; otherwise, assign a default message
|
||||
if isinstance(output, list) and len(output) > 0:
|
||||
if isinstance(output[0], FileOutput):
|
||||
result_url = output[0].url # If output is a list, get the first element
|
||||
else:
|
||||
result_url = output[
|
||||
0
|
||||
] # If output is a list and not a FileOutput, get the first element. Should never happen, but just in case.
|
||||
elif isinstance(output, FileOutput):
|
||||
result_url = output.url # If output is a FileOutput, use the url
|
||||
elif isinstance(output, str):
|
||||
result_url = output # If output is a string (for some reason due to their janky type hinting), use it directly
|
||||
else:
|
||||
result_url = (
|
||||
"No output received" # Fallback message if output is not as expected
|
||||
)
|
||||
result = extract_result(output)
|
||||
|
||||
return result_url
|
||||
return result
|
||||
@@ -0,0 +1,133 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
from replicate.client import Client as ReplicateClient
|
||||
|
||||
from backend.blocks.replicate._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ReplicateCredentialsInput,
|
||||
)
|
||||
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicateModelBlock(Block):
    """
    Block for running any Replicate model with custom inputs.

    This block allows you to:
    - Use any public Replicate model
    - Pass custom inputs as a dictionary
    - Specify model versions
    - Receive the model output as a string, together with a status and
      the name of the model that was run
    """

    class Input(BlockSchema):
        credentials: ReplicateCredentialsInput = CredentialsField(
            description="Enter your Replicate API key to access the model API. You can obtain an API key from https://replicate.com/account/api-tokens.",
        )
        model_name: str = SchemaField(
            description="The Replicate model name (format: 'owner/model-name')",
            placeholder="stability-ai/stable-diffusion",
            advanced=False,
        )
        model_inputs: dict[str, str | int] = SchemaField(
            default={},
            description="Dictionary of inputs to pass to the model",
            placeholder='{"prompt": "a beautiful landscape", "num_outputs": 1}',
            advanced=False,
        )
        version: Optional[str] = SchemaField(
            default=None,
            description="Specific version hash of the model (optional)",
            placeholder="db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
            advanced=True,
        )

    class Output(BlockSchema):
        result: str = SchemaField(description="The output from the Replicate model")
        status: str = SchemaField(description="Status of the prediction")
        model_name: str = SchemaField(description="Name of the model used")
        error: str = SchemaField(description="Error message if any", default="")

    def __init__(self):
        super().__init__(
            id="c40d75a2-d0ea-44c9-a4f6-634bb3bdab1a",
            description="Run Replicate models synchronously",
            categories={BlockCategory.AI},
            input_schema=ReplicateModelBlock.Input,
            output_schema=ReplicateModelBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "model_name": "meta/llama-2-7b-chat",
                "model_inputs": {"prompt": "Hello, world!", "max_new_tokens": 50},
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("result", str),
                ("status", str),
                ("model_name", str),
            ],
            # Replaces `run_model` during block tests so no network call
            # is made; the parameters mirror `run_model`'s signature.
            test_mock={
                "run_model": lambda model_ref, model_inputs, api_key: (
                    "Mock response from Replicate model"
                )
            },
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Execute the Replicate model with the provided inputs.

        Args:
            input_data: The input data containing model name and inputs
            credentials: The API credentials

        Yields:
            ("result", <model output string>), ("status", "succeeded") and
            ("model_name", <requested model name>), in that order.

        Raises:
            RuntimeError: If anything fails while running the model; the
                original exception is attached as ``__cause__``.
        """
        try:
            # Pin to an explicit version when one is given ("owner/name:hash").
            if input_data.version:
                model_ref = f"{input_data.model_name}:{input_data.version}"
            else:
                model_ref = input_data.model_name
            logger.debug("Running Replicate model: %s", model_ref)
            result = await self.run_model(
                model_ref, input_data.model_inputs, credentials.api_key
            )
            yield "result", result
            yield "status", "succeeded"
            yield "model_name", input_data.model_name
        except Exception as e:
            error_msg = f"Unexpected error running Replicate model: {str(e)}"
            logger.error(error_msg)
            # Chain explicitly so the root cause stays visible to callers.
            raise RuntimeError(error_msg) from e

    async def run_model(
        self, model_ref: str, model_inputs: dict, api_key: SecretStr
    ) -> str:
        """
        Run the Replicate model. This method can be mocked for testing.

        Args:
            model_ref: The model reference (e.g., "owner/model-name:version")
            model_inputs: The inputs to pass to the model
            api_key: The Replicate API key as SecretStr

        Returns:
            The model output reduced to a single string by `extract_result`.
        """
        api_key_str = api_key.get_secret_value()
        client = ReplicateClient(api_token=api_key_str)
        # The client's return annotation is looser than what it actually
        # returns, hence the ignore.
        output: ReplicateOutputs = await client.async_run(
            model_ref, input=model_inputs, wait=False
        )  # type: ignore

        return extract_result(output)
|
||||
@@ -18,7 +18,8 @@ from backend.blocks.llm import (
|
||||
AITextSummarizerBlock,
|
||||
LlmModel,
|
||||
)
|
||||
from backend.blocks.replicate_flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
|
||||
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
@@ -291,6 +292,18 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
)
|
||||
],
|
||||
ReplicateModelBlock: [
|
||||
BlockCost(
|
||||
cost_amount=10,
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": replicate_credentials.id,
|
||||
"provider": replicate_credentials.provider,
|
||||
"type": replicate_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
AIImageEditorBlock: [
|
||||
BlockCost(
|
||||
cost_amount=10,
|
||||
|
||||
Reference in New Issue
Block a user