address more comments

This commit is contained in:
Reinier van der Leer
2026-02-12 15:57:35 +01:00
parent ab25516a46
commit 2a46d3fbf4
2 changed files with 64 additions and 23 deletions

View File

@@ -12,6 +12,7 @@ import uuid
 from concurrent.futures import Future, ThreadPoolExecutor
 from pika.adapters.blocking_connection import BlockingChannel
+from pika.exceptions import AMQPChannelError, AMQPConnectionError
 from pika.spec import Basic, BasicProperties
 from prometheus_client import Gauge, start_http_server
@@ -21,7 +22,7 @@ from backend.executor.cluster_lock import ClusterLock
 from backend.util.decorator import error_logged
 from backend.util.logging import TruncatedLogger
 from backend.util.process import AppProcess
-from backend.util.retry import continuous_retry, func_retry
+from backend.util.retry import continuous_retry
 from backend.util.settings import Settings
 
 from .processor import execute_copilot_task, init_worker
@@ -235,7 +236,6 @@ class CoPilotExecutor(AppProcess):
             auto_ack=False,
             consumer_tag="copilot_execution_consumer",
         )
-        run_channel.confirm_delivery()
         logger.info("Starting to consume run messages...")
         run_channel.start_consuming()
         if not self.stop_consuming.is_set():
@@ -278,18 +278,46 @@ class CoPilotExecutor(AppProcess):
     ):
         """Handle run message from DIRECT exchange."""
         delivery_tag = method.delivery_tag
+        # Capture the channel used at message delivery time to ensure we ack
+        # on the correct channel. Delivery tags are channel-scoped and become
+        # invalid if the channel is recreated after reconnection.
+        delivery_channel = _channel
 
-        @func_retry
         def ack_message(reject: bool, requeue: bool):
-            """Acknowledge or reject the message."""
-            channel = self.run_client.get_channel()
-            if reject:
-                channel.connection.add_callback_threadsafe(
-                    lambda: channel.basic_nack(delivery_tag, requeue=requeue)
-                )
-            else:
-                channel.connection.add_callback_threadsafe(
-                    lambda: channel.basic_ack(delivery_tag)
-                )
+            """Acknowledge or reject the message.
+
+            Uses the channel from the original message delivery. If the channel
+            is no longer open (e.g., after reconnection), logs a warning and
+            skips the ack - RabbitMQ will redeliver the message automatically.
+            """
+            try:
+                if not delivery_channel.is_open:
+                    logger.warning(
+                        f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
+                        "Message will be redelivered by RabbitMQ."
+                    )
+                    return
+                if reject:
+                    delivery_channel.connection.add_callback_threadsafe(
+                        lambda: delivery_channel.basic_nack(
+                            delivery_tag, requeue=requeue
+                        )
+                    )
+                else:
+                    delivery_channel.connection.add_callback_threadsafe(
+                        lambda: delivery_channel.basic_ack(delivery_tag)
+                    )
+            except (AMQPChannelError, AMQPConnectionError) as e:
+                # Channel/connection errors indicate stale delivery tag - don't retry
+                logger.warning(
+                    f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
+                    f"error: {e}. Message will be redelivered by RabbitMQ."
+                )
+            except Exception as e:
+                # Other errors might be transient, but log and skip to avoid blocking
+                logger.error(
+                    f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
+                )
 
         # Check if we're shutting down

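For reference, a minimal standalone sketch of the pika pattern this hunk adopts (queue name, connection parameters, and function names are illustrative, not from this repo): BlockingChannel methods are not thread-safe, and delivery tags are only valid on the channel that delivered the message, so a worker thread must capture that channel and hand the ack back to the connection's I/O thread via add_callback_threadsafe().

import pika
from concurrent.futures import ThreadPoolExecutor

connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
channel = connection.channel()
pool = ThreadPoolExecutor(max_workers=4)

def on_message(ch, method, properties, body):
    delivery_channel = ch  # capture the delivering channel; tags are scoped to it
    delivery_tag = method.delivery_tag

    def work():
        ...  # long-running task, runs off the I/O thread

        def ack():
            # Runs back on the connection's I/O thread
            if delivery_channel.is_open:
                delivery_channel.basic_ack(delivery_tag)

        # basic_ack is not thread-safe; marshal it back to the I/O thread
        delivery_channel.connection.add_callback_threadsafe(ack)

    pool.submit(work)

channel.basic_consume(queue="demo", on_message_callback=on_message, auto_ack=False)
channel.start_consuming()

The hunk above additionally guards the registration itself, since the channel or connection can still close between the is_open check and the add_callback_threadsafe() call.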
View File

@@ -1,6 +1,7 @@
"""Redis-based distributed locking for cluster coordination."""
import logging
import threading
import time
from typing import TYPE_CHECKING
@@ -19,6 +20,7 @@ class ClusterLock:
         self.owner_id = owner_id
         self.timeout = timeout
         self._last_refresh = 0.0
+        self._refresh_lock = threading.Lock()
 
     def try_acquire(self) -> str | None:
         """Try to acquire the lock.
@@ -31,7 +33,8 @@ class ClusterLock:
         try:
             success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
             if success:
-                self._last_refresh = time.time()
+                with self._refresh_lock:
+                    self._last_refresh = time.time()
                 return self.owner_id  # Successfully acquired
 
             # Failed to acquire, get current owner
@@ -57,23 +60,27 @@ class ClusterLock:
         Rate limited to at most once every timeout/10 seconds (minimum 1 second).
         During rate limiting, still verifies lock existence but skips TTL extension.
         Setting _last_refresh to 0 bypasses rate limiting for testing.
+        Thread-safe: uses _refresh_lock to protect _last_refresh access.
         """
         # Calculate refresh interval: max(timeout // 10, 1)
         refresh_interval = max(self.timeout // 10, 1)
         current_time = time.time()
 
-        # Check if we're within the rate limit period
+        # Check if we're within the rate limit period (thread-safe read)
+        # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
+        with self._refresh_lock:
+            last_refresh = self._last_refresh
         is_rate_limited = (
-            self._last_refresh > 0
-            and (current_time - self._last_refresh) < refresh_interval
+            last_refresh > 0 and (current_time - last_refresh) < refresh_interval
         )
 
         try:
             # Always verify lock existence, even during rate limiting
             current_value = self.redis.get(self.key)
             if not current_value:
-                self._last_refresh = 0
+                with self._refresh_lock:
+                    self._last_refresh = 0
                 return False
 
             stored_owner = (
@@ -82,7 +89,8 @@ class ClusterLock:
             else str(current_value)
         )
         if stored_owner != self.owner_id:
-            self._last_refresh = 0
+            with self._refresh_lock:
+                self._last_refresh = 0
             return False
 
         # If rate limited, return True but don't update TTL or timestamp
@@ -91,25 +99,30 @@ class ClusterLock:
             # Perform actual refresh
             if self.redis.expire(self.key, self.timeout):
-                self._last_refresh = current_time
+                with self._refresh_lock:
+                    self._last_refresh = current_time
                 return True
 
-            self._last_refresh = 0
+            with self._refresh_lock:
+                self._last_refresh = 0
             return False
         except Exception as e:
             logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
-            self._last_refresh = 0
+            with self._refresh_lock:
+                self._last_refresh = 0
             return False
 
     def release(self):
         """Release the lock."""
-        if self._last_refresh == 0:
-            return
+        with self._refresh_lock:
+            if self._last_refresh == 0:
+                return
 
         try:
             self.redis.delete(self.key)
         except Exception:
             pass
-        self._last_refresh = 0.0
+        with self._refresh_lock:
+            self._last_refresh = 0.0
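To see why _refresh_lock matters, here is an illustrative usage sketch (not code from this commit; the ClusterLock(redis, key, owner_id, timeout) constructor signature is assumed from the attributes set in __init__, and the keepalive loop is hypothetical): a keepalive thread calls refresh() while the owning thread eventually calls release(), and both paths read and write _last_refresh concurrently.

# Illustrative sketch; ClusterLock(redis, key, owner_id, timeout) is an
# assumed signature, inferred from the attributes set in __init__ above.
import threading
import time

from redis import Redis

from backend.executor.cluster_lock import ClusterLock

lock = ClusterLock(Redis(), "cluster:leader", "pod-a", timeout=60)

if lock.try_acquire() == "pod-a":
    stop = threading.Event()

    def keepalive():
        # refresh() rate-limits itself to once per max(timeout // 10, 1)
        # seconds (6s here), so polling every second is cheap: rate-limited
        # calls still verify the key exists but skip the TTL extension.
        while not stop.is_set():
            if not lock.refresh():
                break  # key vanished or another owner took over
            time.sleep(1)

    holder = threading.Thread(target=keepalive, daemon=True)
    holder.start()
    try:
        ...  # leader-only work
    finally:
        stop.set()
        # release() may race a final refresh(); _refresh_lock keeps
        # _last_refresh consistent between the two threads.
        lock.release()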