Compare commits

...

6 Commits

Author SHA1 Message Date
Toran Bruce Richards
e9a15cc028 Merge branch 'dev' into toran/open-2856-handle-failed-replicate-predictions-with-retries-in-all 2025-12-18 18:51:38 +00:00
Cursor Agent
10259387a3 Refactor: Improve replicate retry logic and remove unused import
Co-authored-by: nicholas.tindle <nicholas.tindle@agpt.co>
2025-12-02 18:20:28 +00:00
Toran Bruce Richards
0ad0776bca Merge branch 'dev' into toran/open-2856-handle-failed-replicate-predictions-with-retries-in-all 2025-12-02 17:42:11 +00:00
Nicholas Tindle
1faf903ab0 Merge branch 'dev' into toran/open-2856-handle-failed-replicate-predictions-with-retries-in-all 2025-12-01 13:32:01 -06:00
Toran Bruce Richards
f999c8ccdf Merge branch 'dev' into toran/open-2856-handle-failed-replicate-predictions-with-retries-in-all 2025-11-28 13:26:06 +00:00
Torantulino
3b24884fd7 refactor(backend/blocks): implement run_replicate_with_retry helper function
### Changes 🏗️
- Introduced a new helper function `run_replicate_with_retry` to handle retries for model execution across multiple blocks, improving error handling and reducing code duplication.
- Updated `AIImageCustomizerBlock`, `AIImageGeneratorBlock`, `AIMusicGeneratorBlock`, `AIImageEditorBlock`, `ReplicateFluxAdvancedModelBlock`, and `ReplicateModelBlock` to utilize the new helper function for running models.
2025-11-28 13:20:04 +00:00
7 changed files with 148 additions and 53 deletions

View File

@@ -6,6 +6,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.blocks.replicate._helper import run_replicate_with_retry
from backend.data.block import (
Block,
BlockCategory,
@@ -183,9 +184,10 @@ class AIImageCustomizerBlock(Block):
if images:
input_params["image_input"] = [str(img) for img in images]
output: FileOutput | str = await client.async_run( # type: ignore
output: FileOutput | str = await run_replicate_with_retry( # type: ignore
client,
model_name,
input=input_params,
input_params,
wait=False,
)

View File

@@ -5,6 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.blocks.replicate._helper import run_replicate_with_retry
from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
from backend.data.model import (
APIKeyCredentials,
@@ -181,7 +182,9 @@ class AIImageGeneratorBlock(Block):
client = ReplicateClient(api_token=credentials.api_key.get_secret_value())
# Run the model with input parameters
output = await client.async_run(model_name, input=input_params, wait=False)
output = await run_replicate_with_retry(
client, model_name, input_params, wait=False
)
# Process output
if isinstance(output, list) and len(output) > 0:

View File

@@ -1,11 +1,12 @@
import asyncio
import logging
from enum import Enum
from typing import Literal
from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.blocks.replicate._helper import run_replicate_with_retry
from backend.data.block import (
Block,
BlockCategory,
@@ -43,12 +44,14 @@ class MusicGenModelVersion(str, Enum):
STEREO_LARGE = "stereo-large"
MELODY_LARGE = "melody-large"
LARGE = "large"
MINIMAX_MUSIC_1_5 = "minimax/music-1.5"
# Audio format enum
class AudioFormat(str, Enum):
WAV = "wav"
MP3 = "mp3"
PCM = "pcm"
# Normalization strategy enum
@@ -72,6 +75,14 @@ class AIMusicGeneratorBlock(Block):
placeholder="e.g., 'An upbeat electronic dance track with heavy bass'",
title="Prompt",
)
lyrics: str | None = SchemaField(
description=(
"Lyrics for the song (required for Minimax Music 1.5). "
"Use \\n to separate lines. Supports tags like [intro], [verse], [chorus], etc."
),
default=None,
title="Lyrics",
)
music_gen_model_version: MusicGenModelVersion = SchemaField(
description="Model to use for generation",
default=MusicGenModelVersion.STEREO_LARGE,
@@ -126,6 +137,7 @@ class AIMusicGeneratorBlock(Block):
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"prompt": "An upbeat electronic dance track with heavy bass",
"lyrics": None,
"music_gen_model_version": MusicGenModelVersion.STEREO_LARGE,
"duration": 8,
"temperature": 1.0,
@@ -142,7 +154,7 @@ class AIMusicGeneratorBlock(Block):
),
],
test_mock={
"run_model": lambda api_key, music_gen_model_version, prompt, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: "https://replicate.com/output/generated-audio-url.wav",
"run_model": lambda api_key, music_gen_model_version, prompt, lyrics, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: "https://replicate.com/output/generated-audio-url.wav",
},
test_credentials=TEST_CREDENTIALS,
)
@@ -150,48 +162,35 @@ class AIMusicGeneratorBlock(Block):
async def run(
self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
) -> BlockOutput:
max_retries = 3
retry_delay = 5 # seconds
last_error = None
try:
result = await self.run_model(
api_key=credentials.api_key,
music_gen_model_version=input_data.music_gen_model_version,
prompt=input_data.prompt,
lyrics=input_data.lyrics,
duration=input_data.duration,
temperature=input_data.temperature,
top_k=input_data.top_k,
top_p=input_data.top_p,
classifier_free_guidance=input_data.classifier_free_guidance,
output_format=input_data.output_format,
normalization_strategy=input_data.normalization_strategy,
)
if result and isinstance(result, str) and result.startswith("http"):
yield "result", result
else:
yield "error", "Model returned empty or invalid response"
for attempt in range(max_retries):
try:
logger.debug(
f"[AIMusicGeneratorBlock] - Running model (attempt {attempt + 1})"
)
result = await self.run_model(
api_key=credentials.api_key,
music_gen_model_version=input_data.music_gen_model_version,
prompt=input_data.prompt,
duration=input_data.duration,
temperature=input_data.temperature,
top_k=input_data.top_k,
top_p=input_data.top_p,
classifier_free_guidance=input_data.classifier_free_guidance,
output_format=input_data.output_format,
normalization_strategy=input_data.normalization_strategy,
)
if result and isinstance(result, str) and result.startswith("http"):
yield "result", result
return
else:
last_error = "Model returned empty or invalid response"
raise ValueError(last_error)
except Exception as e:
last_error = f"Unexpected error: {str(e)}"
logger.error(f"[AIMusicGeneratorBlock] - Error: {last_error}")
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
continue
# If we've exhausted all retries, yield the error
yield "error", f"Failed after {max_retries} attempts. Last error: {last_error}"
except Exception as e:
logger.error(f"[AIMusicGeneratorBlock] - Error: {str(e)}")
yield "error", f"Failed to generate music: {str(e)}"
async def run_model(
self,
api_key: SecretStr,
music_gen_model_version: MusicGenModelVersion,
prompt: str,
lyrics: str | None,
duration: int,
temperature: float,
top_k: int,
@@ -203,10 +202,24 @@ class AIMusicGeneratorBlock(Block):
# Initialize Replicate client with the API key
client = ReplicateClient(api_token=api_key.get_secret_value())
# Run the model with parameters
output = await client.async_run(
"meta/musicgen:671ac645ce5e552cc63a54a2bbff63fcf798043055d2dac5fc9e36a837eedcfb",
input={
if music_gen_model_version == MusicGenModelVersion.MINIMAX_MUSIC_1_5:
if not lyrics:
raise ValueError("Lyrics are required for Minimax Music 1.5 model")
# Validate prompt length (10-300 chars)
if len(prompt) < 10:
prompt = prompt.ljust(10, ".")
elif len(prompt) > 300:
prompt = prompt[:300]
input_params = {
"prompt": prompt,
"lyrics": lyrics,
"audio_format": output_format.value,
}
model_name = "minimax/music-1.5"
else:
input_params = {
"prompt": prompt,
"music_gen_model_version": music_gen_model_version,
"duration": duration,
@@ -216,7 +229,15 @@ class AIMusicGeneratorBlock(Block):
"classifier_free_guidance": classifier_free_guidance,
"output_format": output_format,
"normalization_strategy": normalization_strategy,
},
}
model_name = "meta/musicgen:671ac645ce5e552cc63a54a2bbff63fcf798043055d2dac5fc9e36a837eedcfb"
# Run the model with parameters
output = await run_replicate_with_retry(
client,
model_name,
input_params,
wait=True,
)
# Handle the output
@@ -224,6 +245,8 @@ class AIMusicGeneratorBlock(Block):
result_url = output[0] # If output is a list, get the first element
elif isinstance(output, str):
result_url = output # If output is a string, use it directly
elif isinstance(output, FileOutput):
result_url = output.url
else:
result_url = (
"No output received" # Fallback message if output is not as expected

View File

@@ -5,6 +5,7 @@ from pydantic import SecretStr
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
from backend.blocks.replicate._helper import run_replicate_with_retry
from backend.data.block import (
Block,
BlockCategory,
@@ -173,9 +174,10 @@ class AIImageEditorBlock(Block):
**({"seed": seed} if seed is not None else {}),
}
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
output: FileOutput | list[FileOutput] = await run_replicate_with_retry( # type: ignore
client,
model_name,
input=input_params,
input_params=input_params,
wait=False,
)

View File

@@ -1,5 +1,8 @@
import asyncio
import logging
from typing import Any
from replicate.client import Client as ReplicateClient
from replicate.helpers import FileOutput
logger = logging.getLogger(__name__)
@@ -37,3 +40,56 @@ def extract_result(output: ReplicateOutputs) -> str:
)
return result
async def run_replicate_with_retry(
client: ReplicateClient,
model: str,
input_params: dict[str, Any],
wait: bool = False,
max_retries: int = 3,
**kwargs: Any,
) -> Any:
last_error = None
retry_delay = 2 # seconds
for attempt in range(max_retries):
try:
output = await client.async_run(
model, input=input_params, wait=wait, **kwargs
)
# Check for failed status in response
is_failed = False
if isinstance(output, dict) and output.get("status") == "failed":
is_failed = True
elif hasattr(output, "status") and getattr(output, "status") == "failed":
is_failed = True
if is_failed:
# Try to get error message
error_msg = "Replicate prediction failed"
if isinstance(output, dict):
error = output.get("error")
if error:
error_msg = f"{error_msg}: {error}"
elif hasattr(output, "error"):
error = getattr(output, "error")
if error:
error_msg = f"{error_msg}: {error}"
raise RuntimeError(error_msg)
return output
except Exception as e:
last_error = e
if attempt < max_retries - 1:
wait_time = retry_delay * (2**attempt)
logger.warning(
f"Replicate attempt {attempt + 1} failed: {str(e)}. Retrying in {wait_time}s..."
)
await asyncio.sleep(wait_time)
else:
logger.error(f"Replicate failed after {max_retries} attempts: {str(e)}")
raise last_error

View File

@@ -9,7 +9,11 @@ from backend.blocks.replicate._auth import (
TEST_CREDENTIALS_INPUT,
ReplicateCredentialsInput,
)
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
from backend.blocks.replicate._helper import (
ReplicateOutputs,
extract_result,
run_replicate_with_retry,
)
from backend.data.block import (
Block,
BlockCategory,
@@ -188,9 +192,10 @@ class ReplicateFluxAdvancedModelBlock(Block):
client = ReplicateClient(api_token=api_key.get_secret_value())
# Run the model with additional parameters
output: ReplicateOutputs = await client.async_run( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
output: ReplicateOutputs = await run_replicate_with_retry( # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
client,
f"{model_name}",
input={
input_params={
"prompt": prompt,
"seed": seed,
"steps": steps,

View File

@@ -9,7 +9,11 @@ from backend.blocks.replicate._auth import (
TEST_CREDENTIALS_INPUT,
ReplicateCredentialsInput,
)
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
from backend.blocks.replicate._helper import (
ReplicateOutputs,
extract_result,
run_replicate_with_retry,
)
from backend.data.block import (
Block,
BlockCategory,
@@ -129,8 +133,8 @@ class ReplicateModelBlock(Block):
"""
api_key_str = api_key.get_secret_value()
client = ReplicateClient(api_token=api_key_str)
output: ReplicateOutputs = await client.async_run(
model_ref, input=model_inputs, wait=False
output: ReplicateOutputs = await run_replicate_with_retry(
client, model_ref, input_params=model_inputs, wait=False
) # type: ignore they suck at typing
result = extract_result(output)