Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-10 07:38:04 -05:00)
feat(block,backend): Truncate execution update payload on large data & Improve ReadSpreadsheetBlock performance (#10395)
### Changes 🏗️

This PR introduces several key improvements to message handling, block functionality, and execution reliability:

- **Renamed CSV block to Spreadsheet block** with enhanced CSV/Excel processing capabilities
- **Added message size limiting and truncation** for Redis communication to prevent connection issues
- **Optimized FileReadBlock** to yield content chunks instead of duplicated outputs for better performance
- **Improved execution termination handling** with better timeout management and event publishing
- **Enhanced continuous retry decorator** with async function support
- **Implemented payload truncation** to prevent Redis connection issues from oversized messages

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Verified backend starts without errors
  - [x] Confirmed message truncation works for large payloads
  - [x] Tested spreadsheet block functionality with CSV and Excel files
  - [x] Validated execution termination improvements
  - [x] Checked FileReadBlock chunk processing

#### For configuration changes:
- [x] `.env.example` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my changes
- [x] I have included a list of my configuration changes in the PR description (under **Changes**)

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
@@ -12,10 +12,12 @@ class ReadSpreadsheetBlock(Block):
            description="The contents of the CSV/spreadsheet data to read",
            placeholder="a, b, c\n1,2,3\n4,5,6",
            default=None,
            advanced=False,
        )
        file_input: MediaFileType | None = SchemaField(
            description="CSV or Excel file to read from (URL, data URI, or local path). Excel files are automatically converted to CSV",
            default=None,
            advanced=False,
        )
        delimiter: str = SchemaField(
            description="The delimiter used in the CSV/spreadsheet data",
@@ -45,6 +47,10 @@ class ReadSpreadsheetBlock(Block):
            description="The columns to skip from the start of the row",
            default_factory=list,
        )
        produce_singular_result: bool = SchemaField(
            description="If True, yield individual 'row' outputs only (can be slow). If False, yield both 'rows' (all data)",
            default=False,
        )

    class Output(BlockSchema):
        row: dict[str, str] = SchemaField(
@@ -62,9 +68,16 @@ class ReadSpreadsheetBlock(Block):
            description="Reads CSV and Excel files and outputs the data as a list of dictionaries and individual rows. Excel files are automatically converted to CSV format.",
            contributors=[ContributorDetails(name="Nicholas Tindle")],
            categories={BlockCategory.TEXT, BlockCategory.DATA},
            test_input={
                "contents": "a, b, c\n1,2,3\n4,5,6",
            },
            test_input=[
                {
                    "contents": "a, b, c\n1,2,3\n4,5,6",
                    "produce_singular_result": False,
                },
                {
                    "contents": "a, b, c\n1,2,3\n4,5,6",
                    "produce_singular_result": True,
                },
            ],
            test_output=[
                (
                    "rows",
@@ -159,6 +172,8 @@ class ReadSpreadsheetBlock(Block):

        rows = [process_row(row) for row in reader]

        yield "rows", rows
        for processed_row in rows:
            yield "row", processed_row
        if input_data.produce_singular_result:
            for processed_row in rows:
                yield "row", processed_row
        else:
            yield "rows", rows
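The hunk above switches the run logic from always emitting both outputs to emitting one or the other, which is what makes the non-singular path fast. A minimal, self-contained sketch of that behaviour, using a hypothetical `read_spreadsheet` helper rather than the block's actual `run` method:

```python
# Minimal standalone sketch (not the block's real run method) of the
# produce_singular_result behaviour: either stream per-row outputs or emit
# the whole table once, never both.
import csv
import io
from typing import Any, Iterator


def read_spreadsheet(
    contents: str, delimiter: str = ",", produce_singular_result: bool = False
) -> Iterator[tuple[str, Any]]:
    reader = csv.DictReader(io.StringIO(contents), delimiter=delimiter)
    rows = [{k.strip(): v.strip() for k, v in row.items()} for row in reader]

    if produce_singular_result:
        # Per-row streaming: slower, but lets downstream nodes react row by row.
        for row in rows:
            yield "row", row
    else:
        # Single aggregated output: one message instead of N.
        yield "rows", rows


for name, value in read_spreadsheet("a,b,c\n1,2,3\n4,5,6"):
    print(name, value)
```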
@@ -341,9 +341,8 @@ class FileReadBlock(Block):

    class Output(BlockSchema):
        content: str = SchemaField(
            description="The full content of the file or a chunk based on delimiter/limits"
            description="File content, yielded as individual chunks when delimiter or size limits are applied"
        )
        chunk: str = SchemaField(description="Individual chunks when delimiter is used")

    def __init__(self):
        super().__init__(
@@ -418,19 +417,8 @@ class FileReadBlock(Block):
                chunks.append(chunk)
            return chunks

        # Process items and yield chunks
        all_chunks = []
        for item in items:
            if item:  # Only process non-empty items
                chunks = create_chunks(item, input_data.size_limit)
                # Only yield as 'chunk' if we have a delimiter (multiple items)
                if input_data.delimiter:
                    for chunk in chunks:
                        yield "chunk", chunk
                all_chunks.extend(chunks)

        # Yield the processed content
        if all_chunks:
        # Process items and yield as content chunks
        if items:
            full_content = (
                input_data.delimiter.join(items)
                if input_data.delimiter
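The FileReadBlock hunk above drops the duplicated `chunk` stream in favour of yielding each piece once as `content`. A rough standalone sketch of that chunking behaviour, with hypothetical helper names:

```python
# Hedged sketch (standalone, hypothetical helpers) of the FileReadBlock change:
# split on the delimiter, cap each piece at the size limit, and yield every
# piece once as "content" instead of duplicating it into a "chunk" output.
from typing import Iterator


def split_content(text: str, delimiter: str | None, size_limit: int) -> list[str]:
    """Split on the delimiter (if any), then cap each piece at size_limit chars."""
    items = text.split(delimiter) if delimiter else [text]
    pieces: list[str] = []
    for item in items:
        if not item:
            continue  # skip empty items
        pieces.extend(item[i : i + size_limit] for i in range(0, len(item), size_limit))
    return pieces


def read_file_outputs(
    text: str, delimiter: str | None = None, size_limit: int = 1_000_000
) -> Iterator[tuple[str, str]]:
    for piece in split_content(text, delimiter, size_limit):
        yield "content", piece


print(list(read_file_outputs("a\nb\nc", delimiter="\n", size_limit=10)))
```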
@@ -8,8 +8,11 @@ from redis.asyncio.client import PubSub as AsyncPubSub
from redis.client import PubSub

from backend.data import redis_client as redis
from backend.util import json
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
config = Settings().config


M = TypeVar("M", bound=BaseModel)
@@ -28,7 +31,41 @@ class BaseRedisEventBus(Generic[M], ABC):
        return _EventPayloadWrapper[self.Model]

    def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
        message = self.Message(payload=item).model_dump_json()
        MAX_MESSAGE_SIZE = config.max_message_size_limit

        try:
            # Use backend.util.json.dumps which handles datetime and other complex types
            message = json.dumps(
                self.Message(payload=item), ensure_ascii=False, separators=(",", ":")
            )
        except UnicodeError:
            # Fallback to ASCII encoding if Unicode causes issues
            message = json.dumps(
                self.Message(payload=item), ensure_ascii=True, separators=(",", ":")
            )
            logger.warning(
                f"Unicode serialization failed, falling back to ASCII for channel {channel_key}"
            )

        # Check message size and truncate if necessary
        message_size = len(message.encode("utf-8"))
        if message_size > MAX_MESSAGE_SIZE:
            logger.warning(
                f"Message size {message_size} bytes exceeds limit {MAX_MESSAGE_SIZE} bytes for channel {channel_key}. "
                "Truncating payload to prevent Redis connection issues."
            )
            error_payload = {
                "payload": {
                    "event_type": "error_comms_update",
                    "error": "Payload too large for Redis transmission",
                    "original_size_bytes": message_size,
                    "max_size_bytes": MAX_MESSAGE_SIZE,
                }
            }
            message = json.dumps(
                error_payload, ensure_ascii=False, separators=(",", ":")
            )

        channel_name = f"{self.event_bus_name}/{channel_key}"
        logger.debug(f"[{channel_name}] Publishing an event to Redis {message}")
        return message, channel_name
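A self-contained sketch of the size-guard pattern in `_serialize_message` above, with an illustrative `MAX_BYTES` constant standing in for `config.max_message_size_limit`:

```python
# Hedged standalone sketch: serialize, measure the UTF-8 byte size, and if the
# payload exceeds the cap publish a small error marker instead of the
# oversized message. MAX_BYTES is illustrative only.
import json

MAX_BYTES = 1024  # illustrative cap; the PR reads this from config


def serialize_with_cap(payload: dict, channel: str) -> str:
    message = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
    size = len(message.encode("utf-8"))
    if size > MAX_BYTES:
        message = json.dumps(
            {
                "payload": {
                    "event_type": "error_comms_update",
                    "error": "Payload too large for Redis transmission",
                    "original_size_bytes": size,
                    "max_size_bytes": MAX_BYTES,
                }
            },
            separators=(",", ":"),
        )
    return message


print(serialize_with_cap({"data": "x" * 10_000}, "user/graph/exec"))
```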
@@ -40,6 +40,7 @@ from pydantic.fields import Field
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import type as type_utils
from backend.util.settings import Config
from backend.util.truncate import truncate

from .block import (
    BlockInput,
@@ -866,6 +867,7 @@ class ExecutionQueue(Generic[T]):
class ExecutionEventType(str, Enum):
    GRAPH_EXEC_UPDATE = "graph_execution_update"
    NODE_EXEC_UPDATE = "node_execution_update"
    ERROR_COMMS_UPDATE = "error_comms_update"


class GraphExecutionEvent(GraphExecution):
@@ -900,11 +902,25 @@ class RedisExecutionEventBus(RedisEventBus[ExecutionEvent]):

    def publish_node_exec_update(self, res: NodeExecutionResult):
        event = NodeExecutionEvent.model_validate(res.model_dump())
        self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")
        self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")

    def publish_graph_exec_update(self, res: GraphExecution):
        event = GraphExecutionEvent.model_validate(res.model_dump())
        self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")
        self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")

    def _publish(self, event: ExecutionEvent, channel: str):
        """
        truncate inputs and outputs to avoid large payloads
        """
        limit = config.max_message_size_limit // 2
        if isinstance(event, GraphExecutionEvent):
            event.inputs = truncate(event.inputs, limit)
            event.outputs = truncate(event.outputs, limit)
        elif isinstance(event, NodeExecutionEvent):
            event.input_data = truncate(event.input_data, limit)
            event.output_data = truncate(event.output_data, limit)

        super().publish_event(event, channel)

    def listen(
        self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
@@ -928,13 +944,30 @@ class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionEvent]):

    async def publish_node_exec_update(self, res: NodeExecutionResult):
        event = NodeExecutionEvent.model_validate(res.model_dump())
        await self.publish_event(
            event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}"
        )
        await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.graph_exec_id}")

    async def publish_graph_exec_update(self, res: GraphExecutionMeta):
        event = GraphExecutionEvent.model_validate(res.model_dump())
        await self.publish_event(event, f"{res.user_id}/{res.graph_id}/{res.id}")
        # GraphExecutionEvent requires inputs and outputs fields that GraphExecutionMeta doesn't have
        # Add default empty values for compatibility
        event_data = res.model_dump()
        event_data.setdefault("inputs", {})
        event_data.setdefault("outputs", {})
        event = GraphExecutionEvent.model_validate(event_data)
        await self._publish(event, f"{res.user_id}/{res.graph_id}/{res.id}")

    async def _publish(self, event: ExecutionEvent, channel: str):
        """
        truncate inputs and outputs to avoid large payloads
        """
        limit = config.max_message_size_limit // 2
        if isinstance(event, GraphExecutionEvent):
            event.inputs = truncate(event.inputs, limit)
            event.outputs = truncate(event.outputs, limit)
        elif isinstance(event, NodeExecutionEvent):
            event.input_data = truncate(event.input_data, limit)
            event.output_data = truncate(event.output_data, limit)

        await super().publish_event(event, channel)

    async def listen(
        self, user_id: str, graph_id: str = "*", graph_exec_id: str = "*"
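A hedged sketch of the compatibility shim in `publish_graph_exec_update` above: the event model requires `inputs`/`outputs` fields that `GraphExecutionMeta` does not carry, so empty defaults are injected before validation. The models here are simplified stand-ins, not the project's real classes:

```python
# Hedged sketch: inject empty defaults for fields the meta model lacks before
# validating it into the richer event model. Simplified pydantic v2 stand-ins.
from pydantic import BaseModel


class GraphExecutionMeta(BaseModel):
    id: str
    user_id: str
    graph_id: str


class GraphExecutionEvent(GraphExecutionMeta):
    inputs: dict
    outputs: dict


meta = GraphExecutionMeta(id="exec-1", user_id="u1", graph_id="g1")
event_data = meta.model_dump()
event_data.setdefault("inputs", {})   # required by the event model
event_data.setdefault("outputs", {})  # required by the event model
event = GraphExecutionEvent.model_validate(event_data)
print(event)
```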
@@ -685,7 +685,7 @@ async def stop_graph_execution(
    user_id: str,
    graph_exec_id: str,
    use_db_query: bool = True,
    wait_timeout: float = 60.0,
    wait_timeout: float = 15.0,
):
    """
    Mechanism:
@@ -720,33 +720,58 @@ async def stop_graph_execution(
            ExecutionStatus.FAILED,
        ]:
            # If graph execution is terminated/completed/failed, cancellation is complete
            await get_async_execution_event_bus().publish(graph_exec)
            return

        elif graph_exec.status in [
        if graph_exec.status in [
            ExecutionStatus.QUEUED,
            ExecutionStatus.INCOMPLETE,
        ]:
            # If the graph is still on the queue, we can prevent them from being executed
            # by setting the status to TERMINATED.
            node_execs = await db.get_node_executions(
                graph_exec_id=graph_exec_id,
                statuses=[ExecutionStatus.QUEUED, ExecutionStatus.INCOMPLETE],
                include_exec_data=False,
            )
            await db.update_node_execution_status_batch(
            break

        if graph_exec.status == ExecutionStatus.RUNNING:
            await asyncio.sleep(0.1)

    # Set the termination status if the graph is not stopped after the timeout.
    if graph_exec := await db.get_graph_execution_meta(
        execution_id=graph_exec_id, user_id=user_id
    ):
        # If the graph is still on the queue, we can prevent them from being executed
        # by setting the status to TERMINATED.
        node_execs = await db.get_node_executions(
            graph_exec_id=graph_exec_id,
            statuses=[
                ExecutionStatus.QUEUED,
                ExecutionStatus.RUNNING,
            ],
            include_exec_data=False,
        )

        graph_exec.status = ExecutionStatus.TERMINATED
        for node_exec in node_execs:
            node_exec.status = ExecutionStatus.TERMINATED

        await asyncio.gather(
            # Update node execution statuses
            db.update_node_execution_status_batch(
                [node_exec.node_exec_id for node_exec in node_execs],
                ExecutionStatus.TERMINATED,
            )
        await db.update_graph_execution_stats(
            ),
            # Publish node execution events
            *[
                get_async_execution_event_bus().publish(node_exec)
                for node_exec in node_execs
            ],
        )
        await asyncio.gather(
            # Update graph execution status
            db.update_graph_execution_stats(
                graph_exec_id=graph_exec_id,
                status=ExecutionStatus.TERMINATED,
            )

    await asyncio.sleep(1.0)

    raise TimeoutError(
        f"Timed out waiting for graph execution #{graph_exec_id} to terminate."
    )
            ),
            # Publish graph execution event
            get_async_execution_event_bus().publish(graph_exec),
        )


async def add_graph_execution(
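A simplified, self-contained sketch of the termination flow above: poll until the execution reaches a terminal state or the timeout expires, then force the TERMINATED status. `get_status` and `force_terminate` are hypothetical stand-ins for the real database helpers and event publishing:

```python
# Hedged sketch of the poll-then-force-terminate pattern; not the real
# stop_graph_execution, which also updates node executions and publishes events.
import asyncio

TERMINAL = {"COMPLETED", "FAILED", "TERMINATED"}


async def stop_execution(get_status, force_terminate, wait_timeout: float = 15.0):
    loop = asyncio.get_running_loop()
    deadline = loop.time() + wait_timeout
    while loop.time() < deadline:
        status = await get_status()
        if status in TERMINAL:
            return status          # graceful stop observed
        await asyncio.sleep(0.1)   # keep polling
    await force_terminate()        # timeout reached: force TERMINATED
    return "TERMINATED"


async def main():
    state = {"status": "RUNNING"}

    async def get_status():
        return state["status"]

    async def force_terminate():
        state["status"] = "TERMINATED"

    print(await stop_execution(get_status, force_terminate, wait_timeout=0.3))


asyncio.run(main())
```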
@@ -39,6 +39,7 @@ from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.server.external.api import external_app
from backend.server.middleware.security import SecurityHeadersMiddleware
from backend.util import json

settings = backend.util.settings.Settings()
logger = logging.getLogger(__name__)
@@ -151,7 +152,7 @@ def handle_internal_http_error(status_code: int = 500, log_error: bool = True):

async def validation_error_handler(
    request: fastapi.Request, exc: Exception
) -> fastapi.responses.JSONResponse:
) -> fastapi.responses.Response:
    logger.error(
        "Validation failed for %s %s: %s. Fix the request payload and try again.",
        request.method,
@@ -163,13 +164,19 @@ async def validation_error_handler(
        errors = exc.errors()  # type: ignore[call-arg]
    else:
        errors = str(exc)
    return fastapi.responses.JSONResponse(

    response_content = {
        "message": f"Invalid data for {request.method} {request.url.path}",
        "detail": errors,
        "hint": "Ensure the request matches the API schema.",
    }

    content_json = json.dumps(response_content)

    return fastapi.responses.Response(
        content=content_json,
        status_code=422,
        content={
            "message": f"Invalid data for {request.method} {request.url.path}",
            "detail": errors,
            "hint": "Ensure the request matches the API schema.",
        },
        media_type="application/json",
    )
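A hedged sketch of the handler change above: the body is serialized with a tolerant encoder first (plain `json.dumps(default=str)` stands in for `backend.util.json`), then returned as a plain `Response`, so rendering cannot fail on odd error objects inside `JSONResponse`:

```python
# Hedged sketch: pre-serialize the error body, then return a plain Response.
import json

import fastapi
from fastapi.exceptions import RequestValidationError

app = fastapi.FastAPI()


@app.exception_handler(RequestValidationError)
async def validation_error_handler(
    request: fastapi.Request, exc: Exception
) -> fastapi.responses.Response:
    errors = exc.errors() if isinstance(exc, RequestValidationError) else str(exc)
    body = {
        "message": f"Invalid data for {request.method} {request.url.path}",
        "detail": errors,
        "hint": "Ensure the request matches the API schema.",
    }
    return fastapi.responses.Response(
        content=json.dumps(body, default=str),  # default=str is the tolerant fallback
        status_code=422,
        media_type="application/json",
    )
```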
@@ -19,6 +19,7 @@ from backend.server.model import (
    WSSubscribeGraphExecutionRequest,
    WSSubscribeGraphExecutionsRequest,
)
from backend.util.retry import continuous_retry
from backend.util.service import AppProcess
from backend.util.settings import AppEnvironment, Config, Settings
@@ -46,18 +47,11 @@ def get_connection_manager():
    return _connection_manager


@continuous_retry()
async def event_broadcaster(manager: ConnectionManager):
    try:
        event_queue = AsyncRedisExecutionEventBus()
        async for event in event_queue.listen("*"):
            await manager.send_execution_update(event)
    except Exception as e:
        logger.exception(
            "Event broadcaster stopped due to error: %s. "
            "Verify the Redis connection and restart the service.",
            e,
        )
        raise
    event_queue = AsyncRedisExecutionEventBus()
    async for event in event_queue.listen("*"):
        await manager.send_execution_update(event)


async def authenticate_websocket(websocket: WebSocket) -> str:
@@ -85,8 +85,10 @@ func_retry = retry(

def continuous_retry(*, retry_delay: float = 1.0):
    def decorator(func):
        is_coroutine = asyncio.iscoroutinefunction(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
        def sync_wrapper(*args, **kwargs):
            while True:
                try:
                    return func(*args, **kwargs)
@@ -99,6 +101,20 @@ def continuous_retry(*, retry_delay: float = 1.0):
                    )
                    time.sleep(retry_delay)

        return wrapper
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            while True:
                try:
                    return await func(*args, **kwargs)
                except Exception as exc:
                    logger.exception(
                        "%s failed with %s — retrying in %.2f s",
                        func.__name__,
                        exc,
                        retry_delay,
                    )
                    await asyncio.sleep(retry_delay)

        return async_wrapper if is_coroutine else sync_wrapper

    return decorator
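A self-contained sketch of the async-aware `continuous_retry` decorator above, together with a usage example; the real implementation logs via `logger.exception` rather than `print`:

```python
# Hedged, simplified re-implementation plus usage: the decorator picks a sync
# or async retry loop depending on the decorated function's type.
import asyncio
import functools
import time


def continuous_retry(*, retry_delay: float = 1.0):
    def decorator(func):
        is_coroutine = asyncio.iscoroutinefunction(func)

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            while True:
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    print(f"{func.__name__} failed with {exc!r}; retrying")
                    time.sleep(retry_delay)

        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            while True:
                try:
                    return await func(*args, **kwargs)
                except Exception as exc:
                    print(f"{func.__name__} failed with {exc!r}; retrying")
                    await asyncio.sleep(retry_delay)

        # Pick the wrapper that matches the decorated function's type.
        return async_wrapper if is_coroutine else sync_wrapper

    return decorator


calls = {"n": 0}


@continuous_retry(retry_delay=0.1)
async def flaky() -> int:
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return calls["n"]


print(asyncio.run(flaky()))  # prints 3 after two retried failures
```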
@@ -315,6 +315,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        description="A whitelist of trusted internal endpoints for the backend to make requests to.",
    )

    max_message_size_limit: int = Field(
        default=16 * 1024 * 1024,  # 16 MB
        description="Maximum message size limit for communication with the message bus",
    )

    backend_cors_allow_origins: List[str] = Field(default_factory=list)

    @field_validator("backend_cors_allow_origins")
autogpt_platform/backend/backend/util/truncate.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import sys
from typing import Any

# ---------------------------------------------------------------------------
# String helpers
# ---------------------------------------------------------------------------


def _truncate_string_middle(value: str, limit: int) -> str:
    """Shorten *value* to *limit* chars by removing the **middle** portion."""

    if len(value) <= limit:
        return value

    head_len = max(1, limit // 2)
    tail_len = limit - head_len  # ensures total == limit
    omitted = len(value) - (head_len + tail_len)
    return f"{value[:head_len]}… (omitted {omitted} chars)…{value[-tail_len:]}"


# ---------------------------------------------------------------------------
# List helpers
# ---------------------------------------------------------------------------


def _truncate_list_middle(lst: list[Any], str_lim: int, list_lim: int) -> list[Any]:
    """Return *lst* truncated to *list_lim* items, removing from the middle.

    Each retained element is itself recursively truncated via
    :func:`_truncate_value` so we don’t blow the budget with long strings nested
    inside.
    """

    if len(lst) <= list_lim:
        return [_truncate_value(v, str_lim, list_lim) for v in lst]

    # If the limit is very small (<3) fall back to head‑only + sentinel to avoid
    # degenerate splits.
    if list_lim < 3:
        kept = [_truncate_value(v, str_lim, list_lim) for v in lst[:list_lim]]
        kept.append(f"… (omitted {len(lst) - list_lim} items)…")
        return kept

    head_len = list_lim // 2
    tail_len = list_lim - head_len

    head = [_truncate_value(v, str_lim, list_lim) for v in lst[:head_len]]
    tail = [_truncate_value(v, str_lim, list_lim) for v in lst[-tail_len:]]

    omitted = len(lst) - (head_len + tail_len)
    sentinel = f"… (omitted {omitted} items)…"
    return head + [sentinel] + tail


# ---------------------------------------------------------------------------
# Recursive truncation
# ---------------------------------------------------------------------------


def _truncate_value(value: Any, str_limit: int, list_limit: int) -> Any:
    """Recursively truncate *value* using the current per‑type limits."""

    if isinstance(value, str):
        return _truncate_string_middle(value, str_limit)

    if isinstance(value, list):
        return _truncate_list_middle(value, str_limit, list_limit)

    if isinstance(value, dict):
        return {k: _truncate_value(v, str_limit, list_limit) for k, v in value.items()}

    return value


def truncate(value: Any, size_limit: int) -> Any:
    """
    Truncate the given value (recursively) so that its string representation
    does not exceed size_limit characters. Uses binary search to find the
    largest str_limit and list_limit that fit.
    """

    def measure(val):
        try:
            return len(str(val))
        except Exception:
            return sys.getsizeof(val)

    # Reasonable bounds for string and list limits
    STR_MIN, STR_MAX = 8, 2**16
    LIST_MIN, LIST_MAX = 1, 2**12

    # Binary search for the largest str_limit and list_limit that fit
    best = None

    # We'll search str_limit first, then list_limit, but can do both together
    # For practical purposes, do a grid search with binary search on str_limit for each list_limit
    # (since lists are usually the main source of bloat)
    # We'll do binary search on list_limit, and for each, binary search on str_limit

    # Outer binary search on list_limit
    l_lo, l_hi = LIST_MIN, LIST_MAX
    while l_lo <= l_hi:
        l_mid = (l_lo + l_hi) // 2

        # Inner binary search on str_limit
        s_lo, s_hi = STR_MIN, STR_MAX
        local_best = None
        while s_lo <= s_hi:
            s_mid = (s_lo + s_hi) // 2
            truncated = _truncate_value(value, s_mid, l_mid)
            size = measure(truncated)
            if size <= size_limit:
                local_best = truncated
                s_lo = s_mid + 1  # try to increase str_limit
            else:
                s_hi = s_mid - 1  # decrease str_limit

        if local_best is not None:
            best = local_best
            l_lo = l_mid + 1  # try to increase list_limit
        else:
            l_hi = l_mid - 1  # decrease list_limit

    # If nothing fits, fall back to the most aggressive truncation
    if best is None:
        best = _truncate_value(value, STR_MIN, LIST_MIN)

    return best
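A usage sketch for the new `truncate()` helper, assuming the module above is importable at `backend.util.truncate`:

```python
# Usage sketch: a nested payload is reduced until its string representation
# fits the character budget, trimming the middle of long strings and lists.
from backend.util.truncate import truncate

payload = {
    "log": "x" * 500,
    "rows": [{"i": i} for i in range(100)],
}

small = truncate(payload, 200)
print(len(str(small)) <= 200)  # True: the result fits the 200-character budget
print(small["log"])            # head … (omitted N chars) … tail
```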