mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-09 22:35:54 -05:00
Compare commits
13 Commits
fix/execut
...
toran/flux
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a788aa26cc | ||
|
|
212e97767d | ||
|
|
3a743d946f | ||
|
|
ea6b171215 | ||
|
|
75265b16e4 | ||
|
|
bf851f371b | ||
|
|
deb64d8da1 | ||
|
|
ca54d06090 | ||
|
|
cf67551a5f | ||
|
|
e0cdfff030 | ||
|
|
133ed10ecf | ||
|
|
4d0ac7a4c9 | ||
|
|
dc61e784f3 |
@@ -77,6 +77,9 @@ MEDIUM_AUTHOR_ID=
|
|||||||
# Google Maps
|
# Google Maps
|
||||||
GOOGLE_MAPS_API_KEY=
|
GOOGLE_MAPS_API_KEY=
|
||||||
|
|
||||||
|
# Replicate
|
||||||
|
REPLICATE_API_KEY=
|
||||||
|
|
||||||
# Logging Configuration
|
# Logging Configuration
|
||||||
LOG_LEVEL=INFO
|
LOG_LEVEL=INFO
|
||||||
ENABLE_CLOUD_LOGGING=false
|
ENABLE_CLOUD_LOGGING=false
|
||||||
|
|||||||
@@ -0,0 +1,204 @@
|
|||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import replicate
|
||||||
|
|
||||||
|
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||||
|
from backend.data.model import BlockSecret, SchemaField, SecretField
|
||||||
|
|
||||||
|
|
||||||
|
# Model name enum
|
||||||
|
class ReplicateFluxModelName(str, Enum):
|
||||||
|
FLUX_SCHNELL = ("Flux Schnell",)
|
||||||
|
FLUX_PRO = ("Flux Pro",)
|
||||||
|
FLUX_DEV = ("Flux Dev",)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api_name(self):
|
||||||
|
api_names = {
|
||||||
|
ReplicateFluxModelName.FLUX_SCHNELL: "black-forest-labs/flux-schnell",
|
||||||
|
ReplicateFluxModelName.FLUX_PRO: "black-forest-labs/flux-pro",
|
||||||
|
ReplicateFluxModelName.FLUX_DEV: "black-forest-labs/flux-dev",
|
||||||
|
}
|
||||||
|
return api_names[self]
|
||||||
|
|
||||||
|
|
||||||
|
# Image type Enum
|
||||||
|
class ImageType(str, Enum):
|
||||||
|
WEBP = "webp"
|
||||||
|
JPG = "jpg"
|
||||||
|
PNG = "png"
|
||||||
|
|
||||||
|
|
||||||
|
class ReplicateFluxAdvancedModelBlock(Block):
|
||||||
|
class Input(BlockSchema):
|
||||||
|
api_key: BlockSecret = SecretField(
|
||||||
|
key="replicate_api_key",
|
||||||
|
description="Replicate API Key",
|
||||||
|
)
|
||||||
|
prompt: str = SchemaField(
|
||||||
|
description="Text prompt for image generation",
|
||||||
|
placeholder="e.g., 'A futuristic cityscape at sunset'",
|
||||||
|
title="Prompt",
|
||||||
|
)
|
||||||
|
replicate_model_name: ReplicateFluxModelName = SchemaField(
|
||||||
|
description="The name of the Image Generation Model, i.e Flux Schnell",
|
||||||
|
default=ReplicateFluxModelName.FLUX_SCHNELL,
|
||||||
|
title="Image Generation Model",
|
||||||
|
advanced=False,
|
||||||
|
)
|
||||||
|
seed: int | None = SchemaField(
|
||||||
|
description="Random seed. Set for reproducible generation",
|
||||||
|
default=None,
|
||||||
|
title="Seed",
|
||||||
|
)
|
||||||
|
steps: int = SchemaField(
|
||||||
|
description="Number of diffusion steps",
|
||||||
|
default=25,
|
||||||
|
title="Steps",
|
||||||
|
)
|
||||||
|
guidance: float = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Controls the balance between adherence to the text prompt and image quality/diversity. "
|
||||||
|
"Higher values make the output more closely match the prompt but may reduce overall image quality."
|
||||||
|
),
|
||||||
|
default=3,
|
||||||
|
title="Guidance",
|
||||||
|
)
|
||||||
|
interval: float = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Interval is a setting that increases the variance in possible outputs. "
|
||||||
|
"Setting this value low will ensure strong prompt following with more consistent outputs."
|
||||||
|
),
|
||||||
|
default=2,
|
||||||
|
title="Interval",
|
||||||
|
)
|
||||||
|
aspect_ratio: str = SchemaField(
|
||||||
|
description="Aspect ratio for the generated image",
|
||||||
|
default="1:1",
|
||||||
|
title="Aspect Ratio",
|
||||||
|
placeholder="Choose from: 1:1, 16:9, 2:3, 3:2, 4:5, 5:4, 9:16",
|
||||||
|
)
|
||||||
|
output_format: ImageType = SchemaField(
|
||||||
|
description="File format of the output image",
|
||||||
|
default=ImageType.WEBP,
|
||||||
|
title="Output Format",
|
||||||
|
)
|
||||||
|
output_quality: int = SchemaField(
|
||||||
|
description=(
|
||||||
|
"Quality when saving the output images, from 0 to 100. "
|
||||||
|
"Not relevant for .png outputs"
|
||||||
|
),
|
||||||
|
default=80,
|
||||||
|
title="Output Quality",
|
||||||
|
)
|
||||||
|
safety_tolerance: int = SchemaField(
|
||||||
|
description="Safety tolerance, 1 is most strict and 5 is most permissive",
|
||||||
|
default=2,
|
||||||
|
title="Safety Tolerance",
|
||||||
|
)
|
||||||
|
|
||||||
|
class Output(BlockSchema):
|
||||||
|
result: str = SchemaField(description="Generated output")
|
||||||
|
error: str = SchemaField(description="Error message if the model run failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="90f8c45e-e983-4644-aa0b-b4ebe2f531bc",
|
||||||
|
description="This block runs Flux models on Replicate with advanced settings.",
|
||||||
|
categories={BlockCategory.AI},
|
||||||
|
input_schema=ReplicateFluxAdvancedModelBlock.Input,
|
||||||
|
output_schema=ReplicateFluxAdvancedModelBlock.Output,
|
||||||
|
test_input={
|
||||||
|
"api_key": "test_api_key",
|
||||||
|
"replicate_model_name": ReplicateFluxModelName.FLUX_SCHNELL,
|
||||||
|
"prompt": "A beautiful landscape painting of a serene lake at sunrise",
|
||||||
|
"seed": None,
|
||||||
|
"steps": 25,
|
||||||
|
"guidance": 3.0,
|
||||||
|
"interval": 2.0,
|
||||||
|
"aspect_ratio": "1:1",
|
||||||
|
"output_format": ImageType.PNG,
|
||||||
|
"output_quality": 80,
|
||||||
|
"safety_tolerance": 2,
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"result",
|
||||||
|
"https://replicate.com/output/generated-image-url.jpg",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"run_model": lambda api_key, model_name, prompt, seed, steps, guidance, interval, aspect_ratio, output_format, output_quality, safety_tolerance: "https://replicate.com/output/generated-image-url.jpg",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||||
|
# If the seed is not provided, generate a random seed
|
||||||
|
seed = input_data.seed
|
||||||
|
if seed is None:
|
||||||
|
seed = int.from_bytes(os.urandom(4), "big")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run the model using the provided inputs
|
||||||
|
result = self.run_model(
|
||||||
|
api_key=input_data.api_key.get_secret_value(),
|
||||||
|
model_name=input_data.replicate_model_name.api_name,
|
||||||
|
prompt=input_data.prompt,
|
||||||
|
seed=seed,
|
||||||
|
steps=input_data.steps,
|
||||||
|
guidance=input_data.guidance,
|
||||||
|
interval=input_data.interval,
|
||||||
|
aspect_ratio=input_data.aspect_ratio,
|
||||||
|
output_format=input_data.output_format,
|
||||||
|
output_quality=input_data.output_quality,
|
||||||
|
safety_tolerance=input_data.safety_tolerance,
|
||||||
|
)
|
||||||
|
yield "result", result
|
||||||
|
except Exception as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
|
||||||
|
def run_model(
|
||||||
|
self,
|
||||||
|
api_key,
|
||||||
|
model_name,
|
||||||
|
prompt,
|
||||||
|
seed,
|
||||||
|
steps,
|
||||||
|
guidance,
|
||||||
|
interval,
|
||||||
|
aspect_ratio,
|
||||||
|
output_format,
|
||||||
|
output_quality,
|
||||||
|
safety_tolerance,
|
||||||
|
):
|
||||||
|
# Initialize Replicate client with the API key
|
||||||
|
client = replicate.Client(api_token=api_key)
|
||||||
|
|
||||||
|
# Run the model with additional parameters
|
||||||
|
output = client.run(
|
||||||
|
f"{model_name}",
|
||||||
|
input={
|
||||||
|
"prompt": prompt,
|
||||||
|
"seed": seed,
|
||||||
|
"steps": steps,
|
||||||
|
"guidance": guidance,
|
||||||
|
"interval": interval,
|
||||||
|
"aspect_ratio": aspect_ratio,
|
||||||
|
"output_format": output_format,
|
||||||
|
"output_quality": output_quality,
|
||||||
|
"safety_tolerance": safety_tolerance,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if output is a list or a string and extract accordingly; otherwise, assign a default message
|
||||||
|
if isinstance(output, list) and len(output) > 0:
|
||||||
|
result_url = output[0] # If output is a list, get the first element
|
||||||
|
elif isinstance(output, str):
|
||||||
|
result_url = output # If output is a string, use it directly
|
||||||
|
else:
|
||||||
|
result_url = (
|
||||||
|
"No output received" # Fallback message if output is not as expected
|
||||||
|
)
|
||||||
|
|
||||||
|
return result_url
|
||||||
@@ -218,6 +218,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|||||||
|
|
||||||
google_maps_api_key: str = Field(default="", description="Google Maps API Key")
|
google_maps_api_key: str = Field(default="", description="Google Maps API Key")
|
||||||
|
|
||||||
|
replicate_api_key: str = Field(default="", description="Replicate API Key")
|
||||||
# Add more secret fields as needed
|
# Add more secret fields as needed
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
|
|||||||
909
autogpt_platform/backend/poetry.lock
generated
909
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -45,6 +45,7 @@ uvicorn = { extras = ["standard"], version = "^0.30.1" }
|
|||||||
websockets = "^12.0"
|
websockets = "^12.0"
|
||||||
youtube-transcript-api = "^0.6.2"
|
youtube-transcript-api = "^0.6.2"
|
||||||
googlemaps = "^4.10.0"
|
googlemaps = "^4.10.0"
|
||||||
|
replicate = "^0.34.1"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
poethepoet = "^0.26.1"
|
poethepoet = "^0.26.1"
|
||||||
|
|||||||
Reference in New Issue
Block a user