feat(blocks): add Replicate model blocks (#10337)

This PR adds a new block for running Replicate models in AutoGPT:

ReplicateModelBlock: Synchronous execution of any Replicate model with
custom inputs

Key Features:

- Support for any public Replicate model via model name (e.g.,
  "stability-ai/stable-diffusion")
- Custom input parameters via dictionary (e.g., {"prompt": "a beautiful
  landscape"})
- Optional model version specification for reproducible results
- Proper credentials handling using ProviderName.REPLICATE
- Comprehensive test suite with 12 test cases covering success, error, and
  edge cases
- Type-safe implementation with full pyright compliance
- Mock methods for testing and development

# Checklist 📋
## For code changes:
- [X] I have clearly listed my changes in the PR description
- [X] I have made a test plan
- [X] I have tested my changes according to the test plan:
  - Unit tests for ReplicateModelBlock (6 test cases)
  -  Test block initialization and configuration
  -  Test mock methods for development
  -  Test error handling and edge cases
  -  Verify type safety with pyright (0 errors)
  -  Verify code formatting with Black and isort
  -  Verify linting with Ruff (0 errors)
  -  Test credentials handling with ProviderName.REPLICATE
  -  Test model version specification functionality
  -  Test both synchronous and asynchronous execution paths

---------

Co-authored-by: Nicholas Tindle <nicholas.tindle@agpt.co>
This commit is contained in:
snehaoladri
2025-07-17 23:20:44 -05:00
committed by GitHub
parent d33459ddb5
commit 6641a77c70
5 changed files with 221 additions and 45 deletions

View File

@@ -0,0 +1,24 @@
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsMetaInput, ProviderName
# Mock API-key credentials for the Replicate provider. The key is a
# placeholder value intended for block tests, not a usable secret.
TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="replicate",
    api_key=SecretStr("mock-replicate-api-key"),
    title="Mock Replicate API key",
    expires_at=None,
)

# Credentials reference (metadata only, no secret) derived from
# TEST_CREDENTIALS, in the shape used as a block's `credentials` test input.
TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    # The title mirrors the credential's human-readable title, not its type.
    "title": TEST_CREDENTIALS.title,
}

# Shared aliases so every Replicate block declares its credentials the same way.
ReplicateCredentials = APIKeyCredentials
ReplicateCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.REPLICATE], Literal["api_key"]
]

View File

@@ -0,0 +1,39 @@
import logging
from replicate.helpers import FileOutput
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Union of the output shapes handled by extract_result() below: a single
# FileOutput, a list of FileOutputs, a list of strings, a plain string, or a
# list of dicts.
ReplicateOutputs = FileOutput | list[FileOutput] | list[str] | str | list[dict]
def extract_result(output: ReplicateOutputs) -> str:
    """Collapse a Replicate output into a single string.

    Handling by output shape:
    - non-empty list whose first item is a FileOutput: that first item's URL
    - non-empty list whose first item is a str: all items joined together
    - non-empty list whose first item is a dict: ``str()`` of the first dict
    - non-empty list of anything else: a generic "unable to process" message
      (an error is logged)
    - a single FileOutput: its URL
    - a single str: returned unchanged
    - anything else, including an empty list: "No output received"
      (an error is logged)

    Args:
        output: The value returned by the Replicate client for a model run.

    Returns:
        The extracted result, or one of the fallback messages described above.
    """
    if not (isinstance(output, list) and len(output) > 0):
        # Not a usable list: accept a bare FileOutput or string, otherwise give up.
        if isinstance(output, FileOutput):
            return output.url
        if isinstance(output, str):
            return output
        logger.error(
            "We somehow didn't get an output from a replicate block. This is almost certainly an error"
        )
        return "No output received"

    # Only the first element is inspected to decide how to treat the list;
    # checking every element would be slower.
    head = output[0]
    if isinstance(head, FileOutput):
        return head.url
    if isinstance(head, str):
        # Text models return their output as string chunks; stitch them together.
        return "".join(output)  # type: ignore
    if isinstance(head, dict):
        return str(head)

    logger.error(
        "Replicate generated a new output type that's not a file output or a str in a replicate block"
    )
    return "Unable to process result. Please contact us with the models and inputs used"

View File

@@ -1,33 +1,17 @@
import os
from enum import Enum
from typing import Literal
from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.blocks.replicate._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ReplicateCredentialsInput,
)
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
SchemaField,
)
from backend.integrations.providers import ProviderName
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="replicate",
api_key=SecretStr("mock-replicate-api-key"),
title="Mock Replicate API key",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.type,
}
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
# Model name enum
@@ -55,9 +39,7 @@ class ImageType(str, Enum):
class ReplicateFluxAdvancedModelBlock(Block):
class Input(BlockSchema):
credentials: CredentialsMetaInput[
Literal[ProviderName.REPLICATE], Literal["api_key"]
] = CredentialsField(
credentials: ReplicateCredentialsInput = CredentialsField(
description="The Replicate integration can be used with "
"any API key with sufficient permissions for the blocks it is used on.",
)
@@ -201,7 +183,7 @@ class ReplicateFluxAdvancedModelBlock(Block):
client = ReplicateClient(api_token=api_key.get_secret_value())
# Run the model with additional parameters
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
output: ReplicateOutputs = await client.async_run( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
f"{model_name}",
input={
"prompt": prompt,
@@ -217,21 +199,6 @@ class ReplicateFluxAdvancedModelBlock(Block):
wait=False, # don't arbitrarily return data:octect/stream or sometimes url depending on the model???? what is this api
)
# Check if output is a list or a string and extract accordingly; otherwise, assign a default message
if isinstance(output, list) and len(output) > 0:
if isinstance(output[0], FileOutput):
result_url = output[0].url # If output is a list, get the first element
else:
result_url = output[
0
] # If output is a list and not a FileOutput, get the first element. Should never happen, but just in case.
elif isinstance(output, FileOutput):
result_url = output.url # If output is a FileOutput, use the url
elif isinstance(output, str):
result_url = output # If output is a string (for some reason due to their janky type hinting), use it directly
else:
result_url = (
"No output received" # Fallback message if output is not as expected
)
result = extract_result(output)
return result_url
return result

View File

@@ -0,0 +1,133 @@
import logging
from typing import Optional
from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from backend.blocks.replicate._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
ReplicateCredentialsInput,
)
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ReplicateModelBlock(Block):
    """
    Block for running any Replicate model with custom inputs.

    This block allows you to:
    - Use any public Replicate model
    - Pass custom inputs as a dictionary
    - Specify model versions

    On success it emits the extracted model result as a string, a status
    string and the name of the model that was run.
    """

    class Input(BlockSchema):
        credentials: ReplicateCredentialsInput = CredentialsField(
            description="Enter your Replicate API key to access the model API. You can obtain an API key from https://replicate.com/account/api-tokens.",
        )
        model_name: str = SchemaField(
            description="The Replicate model name (format: 'owner/model-name')",
            placeholder="stability-ai/stable-diffusion",
            advanced=False,
        )
        model_inputs: dict[str, str | int] = SchemaField(
            default={},
            description="Dictionary of inputs to pass to the model",
            placeholder='{"prompt": "a beautiful landscape", "num_outputs": 1}',
            advanced=False,
        )
        version: Optional[str] = SchemaField(
            default=None,
            description="Specific version hash of the model (optional)",
            placeholder="db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
            advanced=True,
        )

    class Output(BlockSchema):
        result: str = SchemaField(description="The output from the Replicate model")
        status: str = SchemaField(description="Status of the prediction")
        model_name: str = SchemaField(description="Name of the model used")
        error: str = SchemaField(description="Error message if any", default="")

    def __init__(self):
        super().__init__(
            id="c40d75a2-d0ea-44c9-a4f6-634bb3bdab1a",
            description="Run Replicate models synchronously",
            categories={BlockCategory.AI},
            input_schema=ReplicateModelBlock.Input,
            output_schema=ReplicateModelBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "model_name": "meta/llama-2-7b-chat",
                "model_inputs": {"prompt": "Hello, world!", "max_new_tokens": 50},
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("result", str),
                ("status", str),
                ("model_name", str),
            ],
            # Replaces run_model so the block test never calls the Replicate API.
            test_mock={
                "run_model": lambda model_ref, model_inputs, api_key: (
                    "Mock response from Replicate model"
                )
            },
        )

    async def run(
        self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
    ) -> BlockOutput:
        """
        Execute the Replicate model with the provided inputs.

        Args:
            input_data: The input data containing model name, inputs and
                optional version
            credentials: The API credentials

        Yields:
            ("result", str), ("status", "succeeded") and ("model_name", str),
            in that order.

        Raises:
            RuntimeError: If anything fails while running the model; the
                original exception is attached as the cause.
        """
        try:
            # A pinned version is addressed as "owner/model-name:version".
            if input_data.version:
                model_ref = f"{input_data.model_name}:{input_data.version}"
            else:
                model_ref = input_data.model_name

            logger.debug("Running Replicate model: %s", model_ref)
            result = await self.run_model(
                model_ref, input_data.model_inputs, credentials.api_key
            )

            yield "result", result
            yield "status", "succeeded"
            yield "model_name", input_data.model_name
        except Exception as e:
            error_msg = f"Unexpected error running Replicate model: {e}"
            logger.error(error_msg)
            # Chain explicitly so the underlying failure stays visible as the cause.
            raise RuntimeError(error_msg) from e

    async def run_model(
        self, model_ref: str, model_inputs: dict, api_key: SecretStr
    ) -> str:
        """
        Run the Replicate model. This method can be mocked for testing.

        Args:
            model_ref: The model reference (e.g., "owner/model-name:version")
            model_inputs: The inputs to pass to the model
            api_key: The Replicate API key as SecretStr

        Returns:
            The model output reduced to a single string by extract_result().
        """
        api_key_str = api_key.get_secret_value()
        client = ReplicateClient(api_token=api_key_str)
        output: ReplicateOutputs = await client.async_run(
            model_ref, input=model_inputs, wait=False
        )  # type: ignore  # the client's declared return type does not match ReplicateOutputs
        result = extract_result(output)
        return result

View File

@@ -18,7 +18,8 @@ from backend.blocks.llm import (
AITextSummarizerBlock,
LlmModel,
)
from backend.blocks.replicate_flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
@@ -291,6 +292,18 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
},
)
],
ReplicateModelBlock: [
BlockCost(
cost_amount=10,
cost_filter={
"credentials": {
"id": replicate_credentials.id,
"provider": replicate_credentials.provider,
"type": replicate_credentials.type,
}
},
)
],
AIImageEditorBlock: [
BlockCost(
cost_amount=10,