Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-02-09 22:35:54 -05:00.

Compare commits: dependabot...toran/open (6 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | e9a15cc028 |  |
|  | 10259387a3 |  |
|  | 0ad0776bca |  |
|  | 1faf903ab0 |  |
|  | f999c8ccdf |  |
|  | 3b24884fd7 |  |
@@ -6,6 +6,7 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput
 
+from backend.blocks.replicate._helper import run_replicate_with_retry
 from backend.data.block import (
     Block,
     BlockCategory,
@@ -183,9 +184,10 @@ class AIImageCustomizerBlock(Block):
         if images:
             input_params["image_input"] = [str(img) for img in images]
 
-        output: FileOutput | str = await client.async_run(  # type: ignore
+        output: FileOutput | str = await run_replicate_with_retry(  # type: ignore
+            client,
             model_name,
-            input=input_params,
+            input_params,
             wait=False,
         )
 
@@ -5,6 +5,7 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput
 
+from backend.blocks.replicate._helper import run_replicate_with_retry
 from backend.data.block import Block, BlockCategory, BlockSchemaInput, BlockSchemaOutput
 from backend.data.model import (
     APIKeyCredentials,
@@ -181,7 +182,9 @@ class AIImageGeneratorBlock(Block):
         client = ReplicateClient(api_token=credentials.api_key.get_secret_value())
 
         # Run the model with input parameters
-        output = await client.async_run(model_name, input=input_params, wait=False)
+        output = await run_replicate_with_retry(
+            client, model_name, input_params, wait=False
+        )
 
         # Process output
         if isinstance(output, list) and len(output) > 0:
@@ -1,11 +1,12 @@
 import asyncio
 import logging
 from enum import Enum
 from typing import Literal
 
 from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput
 
+from backend.blocks.replicate._helper import run_replicate_with_retry
 from backend.data.block import (
     Block,
     BlockCategory,
@@ -43,12 +44,14 @@ class MusicGenModelVersion(str, Enum):
     STEREO_LARGE = "stereo-large"
     MELODY_LARGE = "melody-large"
     LARGE = "large"
+    MINIMAX_MUSIC_1_5 = "minimax/music-1.5"
 
 
 # Audio format enum
 class AudioFormat(str, Enum):
     WAV = "wav"
     MP3 = "mp3"
+    PCM = "pcm"
 
 
 # Normalization strategy enum
@@ -72,6 +75,14 @@ class AIMusicGeneratorBlock(Block):
             placeholder="e.g., 'An upbeat electronic dance track with heavy bass'",
             title="Prompt",
         )
+        lyrics: str | None = SchemaField(
+            description=(
+                "Lyrics for the song (required for Minimax Music 1.5). "
+                "Use \\n to separate lines. Supports tags like [intro], [verse], [chorus], etc."
+            ),
+            default=None,
+            title="Lyrics",
+        )
         music_gen_model_version: MusicGenModelVersion = SchemaField(
             description="Model to use for generation",
             default=MusicGenModelVersion.STEREO_LARGE,
@@ -126,6 +137,7 @@ class AIMusicGeneratorBlock(Block):
             test_input={
                 "credentials": TEST_CREDENTIALS_INPUT,
                 "prompt": "An upbeat electronic dance track with heavy bass",
+                "lyrics": None,
                 "music_gen_model_version": MusicGenModelVersion.STEREO_LARGE,
                 "duration": 8,
                 "temperature": 1.0,
@@ -142,7 +154,7 @@ class AIMusicGeneratorBlock(Block):
                 ),
             ],
             test_mock={
-                "run_model": lambda api_key, music_gen_model_version, prompt, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: "https://replicate.com/output/generated-audio-url.wav",
+                "run_model": lambda api_key, music_gen_model_version, prompt, lyrics, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: "https://replicate.com/output/generated-audio-url.wav",
             },
             test_credentials=TEST_CREDENTIALS,
         )
@@ -150,48 +162,35 @@ class AIMusicGeneratorBlock(Block):
     async def run(
         self, input_data: Input, *, credentials: APIKeyCredentials, **kwargs
     ) -> BlockOutput:
-        max_retries = 3
-        retry_delay = 5  # seconds
-        last_error = None
+        try:
+            result = await self.run_model(
+                api_key=credentials.api_key,
+                music_gen_model_version=input_data.music_gen_model_version,
+                prompt=input_data.prompt,
+                lyrics=input_data.lyrics,
+                duration=input_data.duration,
+                temperature=input_data.temperature,
+                top_k=input_data.top_k,
+                top_p=input_data.top_p,
+                classifier_free_guidance=input_data.classifier_free_guidance,
+                output_format=input_data.output_format,
+                normalization_strategy=input_data.normalization_strategy,
+            )
+            if result and isinstance(result, str) and result.startswith("http"):
+                yield "result", result
+            else:
+                yield "error", "Model returned empty or invalid response"
 
-        for attempt in range(max_retries):
-            try:
-                logger.debug(
-                    f"[AIMusicGeneratorBlock] - Running model (attempt {attempt + 1})"
-                )
-                result = await self.run_model(
-                    api_key=credentials.api_key,
-                    music_gen_model_version=input_data.music_gen_model_version,
-                    prompt=input_data.prompt,
-                    duration=input_data.duration,
-                    temperature=input_data.temperature,
-                    top_k=input_data.top_k,
-                    top_p=input_data.top_p,
-                    classifier_free_guidance=input_data.classifier_free_guidance,
-                    output_format=input_data.output_format,
-                    normalization_strategy=input_data.normalization_strategy,
-                )
-                if result and isinstance(result, str) and result.startswith("http"):
-                    yield "result", result
-                    return
-                else:
-                    last_error = "Model returned empty or invalid response"
-                    raise ValueError(last_error)
-            except Exception as e:
-                last_error = f"Unexpected error: {str(e)}"
-                logger.error(f"[AIMusicGeneratorBlock] - Error: {last_error}")
-                if attempt < max_retries - 1:
-                    await asyncio.sleep(retry_delay)
-                    continue
-
-        # If we've exhausted all retries, yield the error
-        yield "error", f"Failed after {max_retries} attempts. Last error: {last_error}"
+        except Exception as e:
+            logger.error(f"[AIMusicGeneratorBlock] - Error: {str(e)}")
+            yield "error", f"Failed to generate music: {str(e)}"
 
     async def run_model(
         self,
         api_key: SecretStr,
         music_gen_model_version: MusicGenModelVersion,
         prompt: str,
+        lyrics: str | None,
         duration: int,
         temperature: float,
         top_k: int,
@@ -203,10 +202,24 @@ class AIMusicGeneratorBlock(Block):
         # Initialize Replicate client with the API key
         client = ReplicateClient(api_token=api_key.get_secret_value())
 
-        # Run the model with parameters
-        output = await client.async_run(
-            "meta/musicgen:671ac645ce5e552cc63a54a2bbff63fcf798043055d2dac5fc9e36a837eedcfb",
-            input={
+        if music_gen_model_version == MusicGenModelVersion.MINIMAX_MUSIC_1_5:
+            if not lyrics:
+                raise ValueError("Lyrics are required for Minimax Music 1.5 model")
+
+            # Validate prompt length (10-300 chars)
+            if len(prompt) < 10:
+                prompt = prompt.ljust(10, ".")
+            elif len(prompt) > 300:
+                prompt = prompt[:300]
+
+            input_params = {
+                "prompt": prompt,
+                "lyrics": lyrics,
+                "audio_format": output_format.value,
+            }
+            model_name = "minimax/music-1.5"
+        else:
+            input_params = {
                 "prompt": prompt,
                 "music_gen_model_version": music_gen_model_version,
                 "duration": duration,
@@ -216,7 +229,15 @@ class AIMusicGeneratorBlock(Block):
                 "classifier_free_guidance": classifier_free_guidance,
                 "output_format": output_format,
                 "normalization_strategy": normalization_strategy,
-            },
-        )
+            }
+            model_name = "meta/musicgen:671ac645ce5e552cc63a54a2bbff63fcf798043055d2dac5fc9e36a837eedcfb"
+
+        # Run the model with parameters
+        output = await run_replicate_with_retry(
+            client,
+            model_name,
+            input_params,
+            wait=True,
+        )
 
         # Handle the output
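Reviewer note on the Minimax branch above: out-of-range prompts are clamped into the 10-300 character window rather than rejected. A small self-contained illustration of that behaviour follows; the `clamp_prompt` helper is hypothetical and only mirrors the padding/truncation shown in the diff.

```python
# Hypothetical stand-alone illustration of the prompt-length clamp used in the
# Minimax branch of the diff above; not part of the commit itself.
def clamp_prompt(prompt: str) -> str:
    if len(prompt) < 10:
        return prompt.ljust(10, ".")  # pad short prompts with dots up to 10 chars
    if len(prompt) > 300:
        return prompt[:300]  # truncate overly long prompts to 300 chars
    return prompt


assert clamp_prompt("hi") == "hi........"  # padded to exactly 10 characters
assert len(clamp_prompt("x" * 500)) == 300  # truncated to exactly 300 characters
```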
@@ -224,6 +245,8 @@
             result_url = output[0]  # If output is a list, get the first element
         elif isinstance(output, str):
             result_url = output  # If output is a string, use it directly
+        elif isinstance(output, FileOutput):
+            result_url = output.url
         else:
             result_url = (
                 "No output received"  # Fallback message if output is not as expected
@@ -5,6 +5,7 @@ from pydantic import SecretStr
 from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput
 
+from backend.blocks.replicate._helper import run_replicate_with_retry
 from backend.data.block import (
     Block,
     BlockCategory,
@@ -173,9 +174,10 @@ class AIImageEditorBlock(Block):
             **({"seed": seed} if seed is not None else {}),
         }
 
-        output: FileOutput | list[FileOutput] = await client.async_run(  # type: ignore
+        output: FileOutput | list[FileOutput] = await run_replicate_with_retry(  # type: ignore
+            client,
             model_name,
-            input=input_params,
+            input_params=input_params,
             wait=False,
         )
 
@@ -1,5 +1,8 @@
+import asyncio
 import logging
 from typing import Any
 
+from replicate.client import Client as ReplicateClient
 from replicate.helpers import FileOutput
 
+logger = logging.getLogger(__name__)
@@ -37,3 +40,56 @@ def extract_result(output: ReplicateOutputs) -> str:
         )
 
     return result
+
+
+async def run_replicate_with_retry(
+    client: ReplicateClient,
+    model: str,
+    input_params: dict[str, Any],
+    wait: bool = False,
+    max_retries: int = 3,
+    **kwargs: Any,
+) -> Any:
+    last_error = None
+    retry_delay = 2  # seconds
+
+    for attempt in range(max_retries):
+        try:
+            output = await client.async_run(
+                model, input=input_params, wait=wait, **kwargs
+            )
+
+            # Check for failed status in response
+            is_failed = False
+            if isinstance(output, dict) and output.get("status") == "failed":
+                is_failed = True
+            elif hasattr(output, "status") and getattr(output, "status") == "failed":
+                is_failed = True
+
+            if is_failed:
+                # Try to get error message
+                error_msg = "Replicate prediction failed"
+                if isinstance(output, dict):
+                    error = output.get("error")
+                    if error:
+                        error_msg = f"{error_msg}: {error}"
+                elif hasattr(output, "error"):
+                    error = getattr(output, "error")
+                    if error:
+                        error_msg = f"{error_msg}: {error}"
+
+                raise RuntimeError(error_msg)
+
+            return output
+
+        except Exception as e:
+            last_error = e
+            if attempt < max_retries - 1:
+                wait_time = retry_delay * (2**attempt)
+                logger.warning(
+                    f"Replicate attempt {attempt + 1} failed: {str(e)}. Retrying in {wait_time}s..."
+                )
+                await asyncio.sleep(wait_time)
+            else:
+                logger.error(f"Replicate failed after {max_retries} attempts: {str(e)}")
+                raise last_error
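For reference, a minimal call-site sketch of the `run_replicate_with_retry` helper defined above. The model slug, prompt, and environment-variable handling are illustrative assumptions, not part of this diff; with `max_retries=3` and the 2-second base delay, failed attempts back off 2s and then 4s before the last error is re-raised.

```python
# Illustrative usage of the retry helper introduced in this compare.
# REPLICATE_API_TOKEN, the model slug, and the prompt are assumptions for the sketch.
import asyncio
import os

from replicate.client import Client as ReplicateClient

from backend.blocks.replicate._helper import run_replicate_with_retry


async def main() -> None:
    client = ReplicateClient(api_token=os.environ["REPLICATE_API_TOKEN"])
    output = await run_replicate_with_retry(
        client,
        "black-forest-labs/flux-schnell",  # any Replicate model reference
        {"prompt": "a lighthouse at dusk, oil painting"},
        wait=False,
        max_retries=3,
    )
    print(output)


if __name__ == "__main__":
    asyncio.run(main())
```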
@@ -9,7 +9,11 @@ from backend.blocks.replicate._auth import (
     TEST_CREDENTIALS_INPUT,
     ReplicateCredentialsInput,
 )
-from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
+from backend.blocks.replicate._helper import (
+    ReplicateOutputs,
+    extract_result,
+    run_replicate_with_retry,
+)
 from backend.data.block import (
     Block,
     BlockCategory,
@@ -188,9 +192,10 @@ class ReplicateFluxAdvancedModelBlock(Block):
         client = ReplicateClient(api_token=api_key.get_secret_value())
 
         # Run the model with additional parameters
-        output: ReplicateOutputs = await client.async_run(  # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
+        output: ReplicateOutputs = await run_replicate_with_retry(  # type: ignore This is because they changed the return type, and didn't update the type hint! It should be overloaded depending on the value of `use_file_output` to `FileOutput | list[FileOutput]` but it's `Any | Iterator[Any]`
+            client,
             f"{model_name}",
-            input={
+            input_params={
                 "prompt": prompt,
                 "seed": seed,
                 "steps": steps,
@@ -9,7 +9,11 @@ from backend.blocks.replicate._auth import (
     TEST_CREDENTIALS_INPUT,
     ReplicateCredentialsInput,
 )
-from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
+from backend.blocks.replicate._helper import (
+    ReplicateOutputs,
+    extract_result,
+    run_replicate_with_retry,
+)
 from backend.data.block import (
     Block,
     BlockCategory,
@@ -129,8 +133,8 @@ class ReplicateModelBlock(Block):
         """
         api_key_str = api_key.get_secret_value()
         client = ReplicateClient(api_token=api_key_str)
-        output: ReplicateOutputs = await client.async_run(
-            model_ref, input=model_inputs, wait=False
+        output: ReplicateOutputs = await run_replicate_with_retry(
+            client, model_ref, input_params=model_inputs, wait=False
         )  # type: ignore they suck at typing
 
         result = extract_result(output)
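On the `# type: ignore` comments above (replicate types `async_run` as `Any | Iterator[Any]` instead of overloading on `use_file_output`): one hedged alternative is to narrow the result at runtime instead of annotating and ignoring. The `normalize_output` helper below is hypothetical, not part of the diff; it simply mirrors the `FileOutput` / list / str handling already used in the music generator block.

```python
# Hypothetical sketch: runtime narrowing of replicate's loosely typed output,
# mirroring the FileOutput / list / str handling used elsewhere in this compare.
from replicate.helpers import FileOutput


def normalize_output(output: object) -> str:
    if isinstance(output, FileOutput):
        return output.url
    if isinstance(output, list) and output:
        first = output[0]
        return first.url if isinstance(first, FileOutput) else str(first)
    if isinstance(output, str):
        return output
    return "No output received"  # fallback when the shape is unexpected
```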