Refactor and reformat Redis block implementation

This commit is contained in:
abhi1992002
2025-03-30 09:47:14 +05:30
parent 0cbf365edb
commit 3a2d52b252

View File

@@ -1,7 +1,8 @@
from pydantic import SecretStr
import redis
from enum import Enum
from typing import Optional, Literal
from typing import Literal, Optional
import redis
from pydantic import SecretStr
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import (
@@ -18,12 +19,14 @@ RedisCredentialsInput = CredentialsMetaInput[
Literal["user_password"],
]
def RedisCredentialsField() -> RedisCredentialsInput:
"""Creates a Redis credentials input on a block."""
return CredentialsField(
description="Redis connection credentials",
)
TEST_REDIS_CREDENTIALS = UserPasswordCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="redis",
@@ -39,35 +42,34 @@ TEST_REDIS_CREDENTIALS_INPUT = {
"title": TEST_REDIS_CREDENTIALS.title,
}
class ListDirection(str, Enum):
LEFT = "LEFT"
RIGHT = "RIGHT"
class SetAction(str, Enum):
ADD = "ADD"
REMOVE = "REMOVE"
class SetQueryAction(str, Enum):
GET_ALL = "GET_ALL" # Corresponds to SMEMBERS
IS_MEMBER = "IS_MEMBER" # Corresponds to SISMEMBER
class SetCondition(str, Enum):
NX = "NX" # Only set the key if it does not already exist.
XX = "XX" # Only set the key if it already exist.
NX = "NX" # Only set the key if it does not already exist.
XX = "XX" # Only set the key if it already exist.
class RedisGetBlock(Block):
"""Retrieves the value stored at a specific key."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379,advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
key: str = SchemaField(description="The key whose value to retrieve.")
class Output(BlockSchema):
success: bool = SchemaField(description="Operation succeeded.")
value: Optional[str] = SchemaField(description="The value stored at the key, or None if the key doesn't exist.")
error:str = SchemaField(description="Error message if operation failed.")
value: Optional[str] = SchemaField(
description="The value stored at the key, or None if the key doesn't exist."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
super().__init__(
@@ -81,16 +83,13 @@ class RedisGetBlock(Block):
"credentials": TEST_REDIS_CREDENTIALS_INPUT,
"host": "localhost",
"port": 6379,
"key": "my_test_key"
"key": "my_test_key",
},
test_output=[
("success", True),
("value", "my_test_value")
],
test_output=[("success", True), ("value", "my_test_value")],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("value", "my_test_value")
("value", "my_test_value"),
]
},
)
@@ -102,9 +101,17 @@ class RedisGetBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True # Decode from bytes to str automatically
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True, # Decode from bytes to str automatically
) as r:
value = r.get(input_data.key)
@@ -114,28 +121,36 @@ class RedisGetBlock(Block):
yield "success", False
yield "error", str(e)
class RedisSetBlock(Block):
"""Stores or updates a value for a key, optionally setting an expiry time or conditions."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379, advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
key: str = SchemaField(description="The key to set.")
value: str = SchemaField(description="The value to store.")
expiration_ms: Optional[int] = SchemaField(
default=None,
description="Optional expiration time in milliseconds (PX).",
advanced=True
advanced=True,
)
condition: Optional[SetCondition] = SchemaField(
default=None,
description="Set condition: NX (Not Exists) or XX (Exists).",
advanced=True
description="Set condition: NX (Not Exists) or XX (Exists). NX = Only set the key if it does not already exist. XX = Only set the key if it already exist.",
advanced=True,
)
class Output(BlockSchema):
success: bool = SchemaField(description="True if the SET operation was successful.")
key_was_set: Optional[bool] = SchemaField(description="True if the key was actually set (especially relevant with NX/XX). Can be None if command fails early.")
success: bool = SchemaField(
description="True if the SET operation was successful."
)
key_was_set: Optional[bool] = SchemaField(
description="True if the key was actually set (especially relevant with NX/XX). Can be None if command fails early."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
@@ -152,17 +167,14 @@ class RedisSetBlock(Block):
"port": 6379,
"key": "my_set_key",
"value": "some data",
"expiration_ms": 60000, # 1 minute
"expiration_ms": 60000, # 1 minute
"condition": None,
},
test_output=[
("success", True),
("key_was_set", True)
],
test_mock={
test_output=[("success", True), ("key_was_set", True)],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("key_was_set", True)
("key_was_set", True),
]
},
)
@@ -174,17 +186,33 @@ class RedisSetBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True,
) as r:
# Prepare arguments for set command
set_args = {
'name': input_data.key,
'value': input_data.value,
'px': input_data.expiration_ms,
'nx': input_data.condition == SetCondition.NX if input_data.condition else None,
'xx': input_data.condition == SetCondition.XX if input_data.condition else None,
"name": input_data.key,
"value": input_data.value,
"px": input_data.expiration_ms,
"nx": (
input_data.condition == SetCondition.NX
if input_data.condition
else None
),
"xx": (
input_data.condition == SetCondition.XX
if input_data.condition
else None
),
}
# Remove None values as redis-py expects keyword args to be present or absent
set_args_filtered = {k: v for k, v in set_args.items() if v is not None}
@@ -200,17 +228,25 @@ class RedisSetBlock(Block):
yield "key_was_set", None
yield "error", str(e)
class RedisDeleteBlock(Block):
"""Removes one or more specified keys and their associated values."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379, advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
keys: list[str] = SchemaField(description="The key(s) to delete.")
class Output(BlockSchema):
success: bool = SchemaField(description="True if the DELETE operation was successful.")
deleted_count: int = SchemaField(description="Number of keys that were actually deleted.")
success: bool = SchemaField(
description="True if the DELETE operation was successful."
)
deleted_count: int = SchemaField(
description="Number of keys that were actually deleted."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
@@ -225,17 +261,11 @@ class RedisDeleteBlock(Block):
"credentials": TEST_REDIS_CREDENTIALS_INPUT,
"host": "localhost",
"port": 6379,
"keys": ["my_key1", "my_key2"]
"keys": ["my_key1", "my_key2"],
},
test_output=[
("success", True),
("deleted_count", 2)
],
test_output=[("success", True), ("deleted_count", 2)],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("deleted_count", 2)
]
"run": lambda *args, **kwargs: [("success", True), ("deleted_count", 2)]
},
)
@@ -246,9 +276,17 @@ class RedisDeleteBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True,
) as r:
# Delete the specified keys
deleted_count = r.delete(*input_data.keys)
@@ -261,17 +299,25 @@ class RedisDeleteBlock(Block):
yield "deleted_count", 0
yield "error", str(e)
class RedisExistsBlock(Block):
"""Checks if one or more specified keys exist in the database."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379, advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
keys: list[str] = SchemaField(description="The key(s) to check for existence.")
class Output(BlockSchema):
success: bool = SchemaField(description="True if the EXISTS operation was successful.")
count: int = SchemaField(description="Number of keys that exist in the database.")
success: bool = SchemaField(
description="True if the EXISTS operation was successful."
)
count: int = SchemaField(
description="Number of keys that exist in the database."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
@@ -286,17 +332,11 @@ class RedisExistsBlock(Block):
"credentials": TEST_REDIS_CREDENTIALS_INPUT,
"host": "localhost",
"port": 6379,
"keys": ["my_key1", "my_key2"]
"keys": ["my_key1", "my_key2"],
},
test_output=[
("success", True),
("count", 1)
],
test_output=[("success", True), ("count", 1)],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("count", 1)
]
"run": lambda *args, **kwargs: [("success", True), ("count", 1)]
},
)
@@ -307,9 +347,17 @@ class RedisExistsBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True,
) as r:
# Check if the specified keys exist
count = r.exists(*input_data.keys)
@@ -322,23 +370,32 @@ class RedisExistsBlock(Block):
yield "count", 0
yield "error", str(e)
class RedisAtomicCounterBlock(Block):
"""Atomically increases or decreases the integer value stored at a key."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379, advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
key: str = SchemaField(description="The key storing the counter value.")
increment: int = SchemaField(description="Amount to increment (positive) or decrement (negative).", default=1)
increment: int = SchemaField(
description="Amount to increment (positive) or decrement (negative).",
default=1,
)
initial_value: Optional[int] = SchemaField(
description="Initial value if key doesn't exist yet.",
default=0,
advanced=True
advanced=True,
)
class Output(BlockSchema):
success: bool = SchemaField(description="True if the operation was successful.")
new_value: int = SchemaField(description="The new value after the increment/decrement operation.")
new_value: int = SchemaField(
description="The new value after the increment/decrement operation."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
@@ -355,17 +412,11 @@ class RedisAtomicCounterBlock(Block):
"port": 6379,
"key": "my_counter",
"increment": 5,
"initial_value": 0
"initial_value": 0,
},
test_output=[
("success", True),
("new_value", 5)
],
test_output=[("success", True), ("new_value", 5)],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("new_value", 5)
]
"run": lambda *args, **kwargs: [("success", True), ("new_value", 5)]
},
)
@@ -376,9 +427,17 @@ class RedisAtomicCounterBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True,
) as r:
if not r.exists(input_data.key) and input_data.initial_value != 0:
r.set(input_data.key, str(input_data.initial_value))
@@ -397,20 +456,26 @@ class RedisAtomicCounterBlock(Block):
yield "new_value", 0
yield "error", str(e)
class RedisInfoBlock(Block):
"""Retrieves information and statistics about the Redis server instance."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379, advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
section: Optional[str] = SchemaField(
description="Optional section of information to retrieve (e.g., 'server', 'clients', 'memory'). If not provided, all sections are returned.",
default=None
default=None,
)
class Output(BlockSchema):
success: bool = SchemaField(description="True if the operation was successful.")
info: dict = SchemaField(description="Dictionary containing server information and statistics.")
info: dict = SchemaField(
description="Dictionary containing server information and statistics."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
@@ -425,16 +490,16 @@ class RedisInfoBlock(Block):
"credentials": TEST_REDIS_CREDENTIALS_INPUT,
"host": "localhost",
"port": 6379,
"section": None
"section": None,
},
test_output=[
("success", True),
("info", {"redis_version": "6.2.6", "uptime_in_seconds": "3600"})
("info", {"redis_version": "6.2.6", "uptime_in_seconds": "3600"}),
],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("info", {"redis_version": "6.2.6", "uptime_in_seconds": "3600"})
("info", {"redis_version": "6.2.6", "uptime_in_seconds": "3600"}),
]
},
)
@@ -446,9 +511,17 @@ class RedisInfoBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True,
) as r:
info = r.info(section=input_data.section)
@@ -461,22 +534,28 @@ class RedisInfoBlock(Block):
yield "info", {}
yield "error", str(e)
class RedisListPushBlock(Block):
"""Adds one or more elements to the beginning or end of a list."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379, advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
key: str = SchemaField(description="The key of the list to push to.")
values: list[str] = SchemaField(description="The value(s) to push to the list.")
direction: ListDirection = SchemaField(
description="Direction to push: LEFT (beginning) or RIGHT (end).",
default=ListDirection.RIGHT
default=ListDirection.RIGHT,
)
class Output(BlockSchema):
success: bool = SchemaField(description="True if the operation was successful.")
new_length: int = SchemaField(description="The new length of the list after the push operation.")
new_length: int = SchemaField(
description="The new length of the list after the push operation."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
@@ -493,17 +572,11 @@ class RedisListPushBlock(Block):
"port": 6379,
"key": "my_list",
"values": ["value1", "value2"],
"direction": ListDirection.RIGHT
"direction": ListDirection.RIGHT,
},
test_output=[
("success", True),
("new_length", 2)
],
test_output=[("success", True), ("new_length", 2)],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("new_length", 2)
]
"run": lambda *args, **kwargs: [("success", True), ("new_length", 2)]
},
)
@@ -514,9 +587,17 @@ class RedisListPushBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True,
) as r:
# Choose push method based on direction
if input_data.direction == ListDirection.LEFT:
@@ -532,26 +613,32 @@ class RedisListPushBlock(Block):
yield "new_length", 0
yield "error", str(e)
class RedisListPopBlock(Block):
"""Removes and returns an element from the beginning or end of a list, optionally waiting if empty."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379, advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
key: str = SchemaField(description="The key of the list to pop from.")
direction: ListDirection = SchemaField(
description="Direction to pop from: LEFT (beginning) or RIGHT (end).",
default=ListDirection.LEFT
default=ListDirection.LEFT,
)
wait_ms: Optional[int] = SchemaField(
description="Time to wait in milliseconds if the list is empty (0 means no wait).",
default=0,
advanced=True
advanced=True,
)
class Output(BlockSchema):
success: bool = SchemaField(description="True if the operation was successful.")
value: Optional[str] = SchemaField(description="The popped element, or None if list is empty and not waiting.")
value: Optional[str] = SchemaField(
description="The popped element, or None if list is empty and not waiting."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
@@ -568,17 +655,11 @@ class RedisListPopBlock(Block):
"port": 6379,
"key": "my_list",
"direction": ListDirection.LEFT,
"wait_ms": 0
"wait_ms": 0,
},
test_output=[
("success", True),
("value", "item1")
],
test_output=[("success", True), ("value", "item1")],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("value", "item1")
]
"run": lambda *args, **kwargs: [("success", True), ("value", "item1")]
},
)
@@ -589,17 +670,33 @@ class RedisListPopBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True,
) as r:
if input_data.wait_ms and input_data.wait_ms > 0:
if input_data.direction == ListDirection.LEFT:
result = r.blpop([input_data.key], timeout=int(input_data.wait_ms/1000))
result = r.blpop(
[input_data.key], timeout=int(input_data.wait_ms / 1000)
)
else:
result = r.brpop([input_data.key], timeout=int(input_data.wait_ms/1000))
result = r.brpop(
[input_data.key], timeout=int(input_data.wait_ms / 1000)
)
value = result[1] if result and isinstance(result, (list, tuple)) else None
value = (
result[1]
if result and isinstance(result, (list, tuple))
else None
)
else:
if input_data.direction == ListDirection.LEFT:
value = r.lpop(input_data.key)
@@ -614,19 +711,30 @@ class RedisListPopBlock(Block):
yield "value", None
yield "error", str(e)
class RedisListGetBlock(Block):
"""Retrieves elements from a list stored at a key."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379, advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
key: str = SchemaField(description="The key of the list to get elements from.")
start: int = SchemaField(description="The starting index (0-based, inclusive).", default=0)
end: int = SchemaField(description="The ending index (inclusive). Use -1 for all elements to the end.", default=-1)
start: int = SchemaField(
description="The starting index (0-based, inclusive).", default=0
)
end: int = SchemaField(
description="The ending index (inclusive). Use -1 for all elements to the end.",
default=-1,
)
class Output(BlockSchema):
success: bool = SchemaField(description="True if the operation was successful.")
values: list[str] = SchemaField(description="The list elements in the specified range.")
values: list[str] = SchemaField(
description="The list elements in the specified range."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
@@ -643,16 +751,13 @@ class RedisListGetBlock(Block):
"port": 6379,
"key": "my_list",
"start": 0,
"end": -1
"end": -1,
},
test_output=[
("success", True),
("values", ["item1", "item2", "item3"])
],
test_output=[("success", True), ("values", ["item1", "item2", "item3"])],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("values", ["item1", "item2", "item3"])
("values", ["item1", "item2", "item3"]),
]
},
)
@@ -664,9 +769,17 @@ class RedisListGetBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True,
) as r:
# Get list elements in the specified range
values = r.lrange(input_data.key, input_data.start, input_data.end)
@@ -682,16 +795,21 @@ class RedisListGetBlock(Block):
class RedisPublishBlock(Block):
"""Sends (publishes) a message to a specific communication channel."""
class Input(BlockSchema):
credentials: RedisCredentialsInput = RedisCredentialsField()
host: str = SchemaField(description="Redis server host address")
port: int = SchemaField(description="Redis server port", default=6379, advanced=False)
port: int = SchemaField(
description="Redis server port", default=6379, advanced=False
)
channel: str = SchemaField(description="The channel to publish the message to.")
message: str = SchemaField(description="The message to publish.")
class Output(BlockSchema):
success: bool = SchemaField(description="True if the operation was successful.")
receivers: int = SchemaField(description="Number of clients that received the message.")
receivers: int = SchemaField(
description="Number of clients that received the message."
)
error: str = SchemaField(description="Error message if operation failed.")
def __init__(self):
@@ -707,17 +825,11 @@ class RedisPublishBlock(Block):
"host": "localhost",
"port": 6379,
"channel": "my_channel",
"message": "Hello Redis!"
"message": "Hello Redis!",
},
test_output=[
("success", True),
("receivers", 1)
],
test_output=[("success", True), ("receivers", 1)],
test_mock={
"run": lambda *args, **kwargs: [
("success", True),
("receivers", 1)
]
"run": lambda *args, **kwargs: [("success", True), ("receivers", 1)]
},
)
@@ -728,9 +840,17 @@ class RedisPublishBlock(Block):
with redis.Redis(
host=input_data.host,
port=input_data.port,
username=credentials.username.get_secret_value() if credentials.username else "default",
password=credentials.password.get_secret_value() if credentials.password else None,
decode_responses=True
username=(
credentials.username.get_secret_value()
if credentials.username
else "default"
),
password=(
credentials.password.get_secret_value()
if credentials.password
else None
),
decode_responses=True,
) as r:
# Publish message to the specified channel
receivers = r.publish(input_data.channel, input_data.message)