mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): Convert generic exceptions to appropriate typed exceptions (#11641)
## Summary - Fix TimeoutError in AIShortformVideoCreatorBlock → BlockExecutionError - Fix generic exceptions in SearchTheWebBlock → BlockExecutionError with proper HTTP error handling - Fix FirecrawlError 504 timeouts → BlockExecutionError with service-specific messages - Fix ReplicateBlock validation errors → BlockInputError for 422 status, BlockExecutionError for others - Add comprehensive HTTP error handling with HTTPClientError/HTTPServerError classes - Implement filename sanitization for "File name too long" errors - Add proper User-Agent handling for Wikipedia API compliance - Fix type conversion for string subclasses like ShortTextType - Add support for moderation errors with proper context propagation ## Test plan - [x] All modified blocks now properly categorize errors instead of raising BlockUnknownError - [x] Type conversion tests pass for ShortTextType and other string subclasses - [x] Formatting and linting pass - [x] Exception constructors include required block_name and block_id parameters 🤖 Generated with [Claude Code](https://claude.ai/code) --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -20,6 +20,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.request import Requests
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
@@ -246,7 +247,11 @@ class AIShortformVideoCreatorBlock(Block):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
raise BlockExecutionError(
|
||||
message="Video creation timed out",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -422,7 +427,11 @@ class AIAdMakerVideoCreatorBlock(Block):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
raise BlockExecutionError(
|
||||
message="Video creation timed out",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -599,7 +608,11 @@ class AIScreenshotToVideoAdBlock(Block):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
logger.error("Video creation timed out")
|
||||
raise TimeoutError("Video creation timed out")
|
||||
raise BlockExecutionError(
|
||||
message="Video creation timed out",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
|
||||
@@ -106,7 +106,10 @@ class ConditionBlock(Block):
|
||||
ComparisonOperator.LESS_THAN_OR_EQUAL: lambda a, b: a <= b,
|
||||
}
|
||||
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
try:
|
||||
result = comparison_funcs[operator](value1, value2)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Comparison failed: {e}") from e
|
||||
|
||||
yield "result", result
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
cost,
|
||||
)
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
from ._config import firecrawl
|
||||
|
||||
@@ -59,11 +60,18 @@ class FirecrawlExtractBlock(Block):
|
||||
) -> BlockOutput:
|
||||
app = FirecrawlApp(api_key=credentials.api_key.get_secret_value())
|
||||
|
||||
extract_result = app.extract(
|
||||
urls=input_data.urls,
|
||||
prompt=input_data.prompt,
|
||||
schema=input_data.output_schema,
|
||||
enable_web_search=input_data.enable_web_search,
|
||||
)
|
||||
try:
|
||||
extract_result = app.extract(
|
||||
urls=input_data.urls,
|
||||
prompt=input_data.prompt,
|
||||
schema=input_data.output_schema,
|
||||
enable_web_search=input_data.enable_web_search,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Extract failed: {e}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
|
||||
yield "data", extract_result.data
|
||||
|
||||
@@ -19,6 +19,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.exceptions import ModerationError
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
@@ -153,6 +154,8 @@ class AIImageEditorBlock(Block):
|
||||
),
|
||||
aspect_ratio=input_data.aspect_ratio.value,
|
||||
seed=input_data.seed,
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
)
|
||||
yield "output_image", result
|
||||
|
||||
@@ -164,6 +167,8 @@ class AIImageEditorBlock(Block):
|
||||
input_image_b64: Optional[str],
|
||||
aspect_ratio: str,
|
||||
seed: Optional[int],
|
||||
user_id: str,
|
||||
graph_exec_id: str,
|
||||
) -> MediaFileType:
|
||||
client = ReplicateClient(api_token=api_key.get_secret_value())
|
||||
input_params = {
|
||||
@@ -173,11 +178,21 @@ class AIImageEditorBlock(Block):
|
||||
**({"seed": seed} if seed is not None else {}),
|
||||
}
|
||||
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
try:
|
||||
output: FileOutput | list[FileOutput] = await client.async_run( # type: ignore
|
||||
model_name,
|
||||
input=input_params,
|
||||
wait=False,
|
||||
)
|
||||
except Exception as e:
|
||||
if "flagged as sensitive" in str(e).lower():
|
||||
raise ModerationError(
|
||||
message="Content was flagged as sensitive by the model provider",
|
||||
user_id=user_id,
|
||||
graph_exec_id=graph_exec_id,
|
||||
moderation_type="model_provider",
|
||||
)
|
||||
raise ValueError(f"Model execution failed: {e}") from e
|
||||
|
||||
if isinstance(output, list) and output:
|
||||
output = output[0]
|
||||
|
||||
@@ -2,7 +2,6 @@ from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
@@ -332,8 +331,8 @@ class IdeogramModelBlock(Block):
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image with V3 endpoint: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch image with V3 endpoint: {e}") from e
|
||||
|
||||
async def _run_model_legacy(
|
||||
self,
|
||||
@@ -385,8 +384,8 @@ class IdeogramModelBlock(Block):
|
||||
try:
|
||||
response = await Requests().post(url, headers=headers, json=data)
|
||||
return response.json()["data"][0]["url"]
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to fetch image with legacy endpoint: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch image with legacy endpoint: {e}") from e
|
||||
|
||||
async def upscale_image(self, api_key: SecretStr, image_url: str):
|
||||
url = "https://api.ideogram.ai/upscale"
|
||||
@@ -413,5 +412,5 @@ class IdeogramModelBlock(Block):
|
||||
|
||||
return (response.json())["data"][0]["url"]
|
||||
|
||||
except RequestException as e:
|
||||
raise Exception(f"Failed to upscale image: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to upscale image: {e}") from e
|
||||
|
||||
@@ -16,6 +16,7 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
|
||||
|
||||
class SearchTheWebBlock(Block, GetRequest):
|
||||
@@ -56,7 +57,17 @@ class SearchTheWebBlock(Block, GetRequest):
|
||||
|
||||
# Prepend the Jina Search URL to the encoded query
|
||||
jina_search_url = f"https://s.jina.ai/{encoded_query}"
|
||||
results = await self.get_request(jina_search_url, headers=headers, json=False)
|
||||
|
||||
try:
|
||||
results = await self.get_request(
|
||||
jina_search_url, headers=headers, json=False
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Search failed: {e}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
|
||||
# Output the search results
|
||||
yield "results", results
|
||||
|
||||
@@ -18,6 +18,7 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError, BlockInputError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -111,9 +112,27 @@ class ReplicateModelBlock(Block):
|
||||
yield "status", "succeeded"
|
||||
yield "model_name", input_data.model_name
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error running Replicate model: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
error_msg = str(e)
|
||||
logger.error(f"Error running Replicate model: {error_msg}")
|
||||
|
||||
# Input validation errors (422, 400) → BlockInputError
|
||||
if (
|
||||
"422" in error_msg
|
||||
or "Input validation failed" in error_msg
|
||||
or "400" in error_msg
|
||||
):
|
||||
raise BlockInputError(
|
||||
message=f"Invalid model inputs: {error_msg}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
# Everything else → BlockExecutionError
|
||||
else:
|
||||
raise BlockExecutionError(
|
||||
message=f"Replicate model error: {error_msg}",
|
||||
block_name=self.name,
|
||||
block_id=self.id,
|
||||
) from e
|
||||
|
||||
async def run_model(self, model_ref: str, model_inputs: dict, api_key: SecretStr):
|
||||
"""
|
||||
|
||||
@@ -45,10 +45,16 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
topic = input_data.topic
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||
response = await self.get_request(url, json=True)
|
||||
if "extract" not in response:
|
||||
raise RuntimeError(f"Unable to parse Wikipedia response: {response}")
|
||||
yield "summary", response["extract"]
|
||||
|
||||
# Note: User-Agent is now automatically set by the request library
|
||||
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
|
||||
try:
|
||||
response = await self.get_request(url, json=True)
|
||||
if "extract" not in response:
|
||||
raise ValueError(f"Unable to parse Wikipedia response: {response}")
|
||||
yield "summary", response["extract"]
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch Wikipedia summary: {e}") from e
|
||||
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
|
||||
@@ -14,12 +14,47 @@ from backend.util.virus_scanner import scan_content_safe
|
||||
|
||||
TEMP_DIR = Path(tempfile.gettempdir()).resolve()
|
||||
|
||||
# Maximum filename length (conservative limit for most filesystems)
|
||||
MAX_FILENAME_LENGTH = 200
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""
|
||||
Sanitize and truncate filename to prevent filesystem errors.
|
||||
"""
|
||||
# Remove or replace invalid characters
|
||||
sanitized = re.sub(r'[<>:"/\\|?*\n\r\t]', "_", filename)
|
||||
|
||||
# Truncate if too long
|
||||
if len(sanitized) > MAX_FILENAME_LENGTH:
|
||||
# Keep the extension if possible
|
||||
if "." in sanitized:
|
||||
name, ext = sanitized.rsplit(".", 1)
|
||||
max_name_length = MAX_FILENAME_LENGTH - len(ext) - 1
|
||||
sanitized = name[:max_name_length] + "." + ext
|
||||
else:
|
||||
sanitized = sanitized[:MAX_FILENAME_LENGTH]
|
||||
|
||||
# Ensure it's not empty or just dots
|
||||
if not sanitized or sanitized.strip(".") == "":
|
||||
sanitized = f"file_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
return sanitized
|
||||
|
||||
|
||||
def get_exec_file_path(graph_exec_id: str, path: str) -> str:
|
||||
"""
|
||||
Utility to build an absolute path in the {temp}/exec_file/{exec_id}/... folder.
|
||||
"""
|
||||
return str(TEMP_DIR / "exec_file" / graph_exec_id / path)
|
||||
try:
|
||||
full_path = TEMP_DIR / "exec_file" / graph_exec_id / path
|
||||
return str(full_path)
|
||||
except OSError as e:
|
||||
if "File name too long" in str(e):
|
||||
raise ValueError(
|
||||
f"File path too long: {len(path)} characters. Maximum path length exceeded."
|
||||
) from e
|
||||
raise ValueError(f"Invalid file path: {e}") from e
|
||||
|
||||
|
||||
def clean_exec_files(graph_exec_id: str, file: str = "") -> None:
|
||||
@@ -117,8 +152,11 @@ async def store_media_file(
|
||||
|
||||
# Generate filename from cloud path
|
||||
_, path_part = cloud_storage.parse_cloud_path(file)
|
||||
filename = Path(path_part).name or f"{uuid.uuid4()}.bin"
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
filename = sanitize_filename(Path(path_part).name or f"{uuid.uuid4()}.bin")
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||
|
||||
# Check file size limit
|
||||
if len(cloud_content) > MAX_FILE_SIZE:
|
||||
@@ -144,7 +182,10 @@ async def store_media_file(
|
||||
# Generate filename and decode
|
||||
extension = _extension_from_mime(mime_type)
|
||||
filename = f"{uuid.uuid4()}{extension}"
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||
content = base64.b64decode(b64_content)
|
||||
|
||||
# Check file size limit
|
||||
@@ -160,8 +201,11 @@ async def store_media_file(
|
||||
elif file.startswith(("http://", "https://")):
|
||||
# URL
|
||||
parsed_url = urlparse(file)
|
||||
filename = Path(parsed_url.path).name or f"{uuid.uuid4()}"
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
filename = sanitize_filename(Path(parsed_url.path).name or f"{uuid.uuid4()}")
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / filename, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{filename}': {e}") from e
|
||||
|
||||
# Download and save
|
||||
resp = await Requests().get(file)
|
||||
@@ -177,8 +221,12 @@ async def store_media_file(
|
||||
target_path.write_bytes(resp.content)
|
||||
|
||||
else:
|
||||
# Local path
|
||||
target_path = _ensure_inside_base(base_path / file, base_path)
|
||||
# Local path - sanitize the filename part to prevent long filename errors
|
||||
sanitized_file = sanitize_filename(file)
|
||||
try:
|
||||
target_path = _ensure_inside_base(base_path / sanitized_file, base_path)
|
||||
except OSError as e:
|
||||
raise ValueError(f"Invalid file path '{sanitized_file}': {e}") from e
|
||||
if not target_path.is_file():
|
||||
raise ValueError(f"Local file does not exist: {target_path}")
|
||||
|
||||
|
||||
@@ -21,6 +21,26 @@ from tenacity import (
|
||||
|
||||
from backend.util.json import loads
|
||||
|
||||
|
||||
class HTTPClientError(Exception):
|
||||
"""4xx client errors (400-499)"""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class HTTPServerError(Exception):
|
||||
"""5xx server errors (500-599)"""
|
||||
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
# Default User-Agent for all requests
|
||||
DEFAULT_USER_AGENT = "AutoGPT-Platform/1.0 (https://github.com/Significant-Gravitas/AutoGPT; info@agpt.co) aiohttp"
|
||||
|
||||
# Retry status codes for which we will automatically retry the request
|
||||
THROTTLE_RETRY_STATUS_CODES: set[int] = {429, 500, 502, 503, 504, 408}
|
||||
|
||||
@@ -450,6 +470,10 @@ class Requests:
|
||||
if self.extra_headers is not None:
|
||||
req_headers.update(self.extra_headers)
|
||||
|
||||
# Set default User-Agent if not provided
|
||||
if "User-Agent" not in req_headers and "user-agent" not in req_headers:
|
||||
req_headers["User-Agent"] = DEFAULT_USER_AGENT
|
||||
|
||||
# Override Host header if using IP connection
|
||||
if connector:
|
||||
req_headers["Host"] = hostname
|
||||
@@ -476,9 +500,16 @@ class Requests:
|
||||
response.raise_for_status()
|
||||
except ClientResponseError as e:
|
||||
body = await response.read()
|
||||
raise Exception(
|
||||
f"HTTP {response.status} Error: {response.reason}, Body: {body.decode(errors='replace')}"
|
||||
) from e
|
||||
error_message = f"HTTP {response.status} Error: {response.reason}, Body: {body.decode(errors='replace')}"
|
||||
|
||||
# Raise specific exceptions based on status code range
|
||||
if 400 <= response.status <= 499:
|
||||
raise HTTPClientError(error_message, response.status) from e
|
||||
elif 500 <= response.status <= 599:
|
||||
raise HTTPServerError(error_message, response.status) from e
|
||||
else:
|
||||
# Generic fallback for other HTTP errors
|
||||
raise Exception(error_message) from e
|
||||
|
||||
# If allowed and a redirect is received, follow the redirect manually
|
||||
if allow_redirects and response.status in (301, 302, 303, 307, 308):
|
||||
|
||||
@@ -5,6 +5,13 @@ from typing import Any, Type, TypeVar, Union, cast, get_args, get_origin, overlo
|
||||
from prisma import Json as PrismaJson
|
||||
|
||||
|
||||
def _is_type_or_subclass(origin: Any, target_type: type) -> bool:
|
||||
"""Check if origin is exactly the target type or a subclass of it."""
|
||||
return origin is target_type or (
|
||||
isinstance(origin, type) and issubclass(origin, target_type)
|
||||
)
|
||||
|
||||
|
||||
class ConversionError(ValueError):
|
||||
pass
|
||||
|
||||
@@ -138,7 +145,11 @@ def _try_convert(value: Any, target_type: Any, raise_on_mismatch: bool) -> Any:
|
||||
|
||||
if origin is None:
|
||||
origin = target_type
|
||||
if origin not in [list, dict, tuple, str, set, int, float, bool]:
|
||||
# Early return for unsupported types (skip subclasses of supported types)
|
||||
supported_types = [list, dict, tuple, str, set, int, float, bool]
|
||||
if origin not in supported_types and not (
|
||||
isinstance(origin, type) and any(issubclass(origin, t) for t in supported_types)
|
||||
):
|
||||
return value
|
||||
|
||||
# Handle the case when value is already of the target type
|
||||
@@ -168,44 +179,47 @@ def _try_convert(value: Any, target_type: Any, raise_on_mismatch: bool) -> Any:
|
||||
raise TypeError(f"Value {value} is not of expected type {target_type}")
|
||||
else:
|
||||
# Need to convert value to the origin type
|
||||
if origin is list:
|
||||
value = __convert_list(value)
|
||||
if _is_type_or_subclass(origin, list):
|
||||
converted_list = __convert_list(value)
|
||||
if args:
|
||||
return [convert(v, args[0]) for v in value]
|
||||
else:
|
||||
return value
|
||||
elif origin is dict:
|
||||
value = __convert_dict(value)
|
||||
converted_list = [convert(v, args[0]) for v in converted_list]
|
||||
return origin(converted_list) if origin is not list else converted_list
|
||||
elif _is_type_or_subclass(origin, dict):
|
||||
converted_dict = __convert_dict(value)
|
||||
if args:
|
||||
key_type, val_type = args
|
||||
return {
|
||||
convert(k, key_type): convert(v, val_type) for k, v in value.items()
|
||||
converted_dict = {
|
||||
convert(k, key_type): convert(v, val_type)
|
||||
for k, v in converted_dict.items()
|
||||
}
|
||||
else:
|
||||
return value
|
||||
elif origin is tuple:
|
||||
value = __convert_tuple(value)
|
||||
return origin(converted_dict) if origin is not dict else converted_dict
|
||||
elif _is_type_or_subclass(origin, tuple):
|
||||
converted_tuple = __convert_tuple(value)
|
||||
if args:
|
||||
if len(args) == 1:
|
||||
return tuple(convert(v, args[0]) for v in value)
|
||||
converted_tuple = tuple(
|
||||
convert(v, args[0]) for v in converted_tuple
|
||||
)
|
||||
else:
|
||||
return tuple(convert(v, t) for v, t in zip(value, args))
|
||||
else:
|
||||
return value
|
||||
elif origin is str:
|
||||
return __convert_str(value)
|
||||
elif origin is set:
|
||||
converted_tuple = tuple(
|
||||
convert(v, t) for v, t in zip(converted_tuple, args)
|
||||
)
|
||||
return origin(converted_tuple) if origin is not tuple else converted_tuple
|
||||
elif _is_type_or_subclass(origin, str):
|
||||
converted_str = __convert_str(value)
|
||||
return origin(converted_str) if origin is not str else converted_str
|
||||
elif _is_type_or_subclass(origin, set):
|
||||
value = __convert_set(value)
|
||||
if args:
|
||||
return {convert(v, args[0]) for v in value}
|
||||
else:
|
||||
return value
|
||||
elif origin is int:
|
||||
return __convert_num(value, int)
|
||||
elif origin is float:
|
||||
return __convert_num(value, float)
|
||||
elif origin is bool:
|
||||
elif _is_type_or_subclass(origin, bool):
|
||||
return __convert_bool(value)
|
||||
elif _is_type_or_subclass(origin, int):
|
||||
return __convert_num(value, int)
|
||||
elif _is_type_or_subclass(origin, float):
|
||||
return __convert_num(value, float)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
@@ -32,3 +32,17 @@ def test_type_conversion():
|
||||
assert convert("5", List[int]) == [5]
|
||||
assert convert("[5,4,2]", List[int]) == [5, 4, 2]
|
||||
assert convert([5, 4, 2], List[str]) == ["5", "4", "2"]
|
||||
|
||||
# Test the specific case that was failing: empty list to Optional[str]
|
||||
assert convert([], Optional[str]) == "[]"
|
||||
assert convert([], str) == "[]"
|
||||
|
||||
# Test the actual failing case: empty list to ShortTextType
|
||||
from backend.util.type import ShortTextType
|
||||
|
||||
assert convert([], Optional[ShortTextType]) == "[]"
|
||||
assert convert([], ShortTextType) == "[]"
|
||||
|
||||
# Test other empty list conversions
|
||||
assert convert([], int) == 0 # len([]) = 0
|
||||
assert convert([], Optional[int]) == 0
|
||||
|
||||
Reference in New Issue
Block a user