Refactor ReplicateFluxAdvancedModelBlock to use an enum for replicate_model_name rather than free strings.

This commit is contained in:
Toran Bruce Richards
2024-09-29 19:17:34 +01:00
parent 4d0ac7a4c9
commit 133ed10ecf

View File

@@ -1,18 +1,37 @@
from enum import Enum
import replicate
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import BlockSecret, SchemaField, SecretField
# Model name enum
class ReplicateFluxModelName(str, Enum):
FLUX_SCHNELL = ("Flux Schnell",)
FLUX_PRO = ("Flux Pro",)
FLUX_DEV = ("Flux Dev",)
@property
def api_name(self):
api_names = {
ReplicateFluxModelName.FLUX_SCHNELL: "black-forest-labs/flux-schnell",
ReplicateFluxModelName.FLUX_PRO: "black-forest-labs/flux-pro",
ReplicateFluxModelName.FLUX_DEV: "black-forest-labs/flux-dev",
}
return api_names[self]
class ReplicateFluxAdvancedModelBlock(Block):
class Input(BlockSchema):
api_key: BlockSecret = SecretField(
key="replicate_api_key",
description="Replicate API Key",
)
replicate_model_name: str = SchemaField(
description="The name of the model on Replicate (e.g., 'black-forest-labs/flux-schnell')",
placeholder="e.g., 'black-forest-labs/flux-schnell'",
title="Replicate Model Name",
replicate_model_name: ReplicateFluxModelName = SchemaField(
description="The name of the Image Generation Model, i.e Flux Schnell",
default=ReplicateFluxModelName.FLUX_SCHNELL,
title="Image Generation Model",
)
prompt: str = SchemaField(
description="Text prompt for image generation",
@@ -83,7 +102,7 @@ class ReplicateFluxAdvancedModelBlock(Block):
output_schema=ReplicateFluxAdvancedModelBlock.Output,
test_input={
"api_key": "your_test_api_key",
"replicate_model_name": "black-forest-labs/flux-schnell",
"replicate_model_name": ReplicateFluxModelName.FLUX_SCHNELL,
"prompt": "A beautiful landscape painting of a serene lake at sunrise",
"steps": 25,
"guidance": 3.0,