mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-02 10:55:14 -05:00
Compare commits
2 Commits
ntindle/fi
...
fix/rabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73651d7f38 | ||
|
|
1081590384 |
@@ -291,6 +291,18 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
exchange, routing_key=queue.routing_key or queue.name
|
||||
)
|
||||
|
||||
async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||
"""Get a valid channel, reconnecting if the current one is stale."""
|
||||
if not self.is_ready:
|
||||
# Reset stale channel so connect() creates a fresh one
|
||||
self._channel = None
|
||||
await self.connect()
|
||||
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after connect")
|
||||
|
||||
return self._channel
|
||||
|
||||
@func_retry
|
||||
async def publish_message(
|
||||
self,
|
||||
@@ -299,32 +311,58 @@ class AsyncRabbitMQ(RabbitMQBase):
|
||||
exchange: Optional[Exchange] = None,
|
||||
persistent: bool = True,
|
||||
) -> None:
|
||||
if not self.is_ready:
|
||||
try:
|
||||
channel = await self._ensure_channel()
|
||||
except Exception:
|
||||
# Force full reconnect on channel acquisition failure
|
||||
self._channel = None
|
||||
self._connection = None
|
||||
channel = await self._ensure_channel()
|
||||
|
||||
try:
|
||||
if exchange:
|
||||
exchange_obj = await channel.get_exchange(exchange.name)
|
||||
else:
|
||||
exchange_obj = channel.default_exchange
|
||||
|
||||
await exchange_obj.publish(
|
||||
aio_pika.Message(
|
||||
body=message.encode(),
|
||||
delivery_mode=(
|
||||
aio_pika.DeliveryMode.PERSISTENT
|
||||
if persistent
|
||||
else aio_pika.DeliveryMode.NOT_PERSISTENT
|
||||
),
|
||||
),
|
||||
routing_key=routing_key,
|
||||
)
|
||||
except aio_pika.exceptions.ChannelInvalidStateError:
|
||||
logger.warning(
|
||||
"RabbitMQ channel invalid, reconnecting and retrying publish"
|
||||
)
|
||||
self._channel = None
|
||||
self._connection = None
|
||||
await self.connect()
|
||||
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after connect")
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after reconnect")
|
||||
|
||||
if exchange:
|
||||
exchange_obj = await self._channel.get_exchange(exchange.name)
|
||||
else:
|
||||
exchange_obj = self._channel.default_exchange
|
||||
if exchange:
|
||||
exchange_obj = await self._channel.get_exchange(exchange.name)
|
||||
else:
|
||||
exchange_obj = self._channel.default_exchange
|
||||
|
||||
await exchange_obj.publish(
|
||||
aio_pika.Message(
|
||||
body=message.encode(),
|
||||
delivery_mode=(
|
||||
aio_pika.DeliveryMode.PERSISTENT
|
||||
if persistent
|
||||
else aio_pika.DeliveryMode.NOT_PERSISTENT
|
||||
await exchange_obj.publish(
|
||||
aio_pika.Message(
|
||||
body=message.encode(),
|
||||
delivery_mode=(
|
||||
aio_pika.DeliveryMode.PERSISTENT
|
||||
if persistent
|
||||
else aio_pika.DeliveryMode.NOT_PERSISTENT
|
||||
),
|
||||
),
|
||||
),
|
||||
routing_key=routing_key,
|
||||
)
|
||||
routing_key=routing_key,
|
||||
)
|
||||
|
||||
async def get_channel(self) -> aio_pika.abc.AbstractChannel:
|
||||
if not self.is_ready:
|
||||
await self.connect()
|
||||
if self._channel is None:
|
||||
raise RuntimeError("Channel should be established after connect")
|
||||
return self._channel
|
||||
return await self._ensure_channel()
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from backend.api.features.integrations.router import router as integrations_router
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import utils as webhooks_utils
|
||||
|
||||
|
||||
def test_webhook_ingress_url_matches_route(monkeypatch) -> None:
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(integrations_router, prefix="/api/integrations")
|
||||
|
||||
provider = ProviderName.GITHUB
|
||||
webhook_id = "webhook_123"
|
||||
base_url = "https://example.com"
|
||||
|
||||
monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url)
|
||||
|
||||
route = next(
|
||||
route
|
||||
for route in integrations_router.routes
|
||||
if isinstance(route, APIRoute)
|
||||
and route.path == "/{provider}/webhooks/{webhook_id}/ingress"
|
||||
and "POST" in route.methods
|
||||
)
|
||||
expected_path = f"/api/integrations{route.path}".format(
|
||||
provider=provider.value,
|
||||
webhook_id=webhook_id,
|
||||
)
|
||||
actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id))
|
||||
expected_base = urlparse(base_url)
|
||||
|
||||
assert (actual_url.scheme, actual_url.netloc) == (
|
||||
expected_base.scheme,
|
||||
expected_base.netloc,
|
||||
)
|
||||
assert actual_url.path == expected_path
|
||||
Reference in New Issue
Block a user