initial setup of API in Core

This commit is contained in:
LeonOstrez
2025-05-09 10:21:17 +02:00
parent d03f15d91a
commit 478783edbf
6 changed files with 223 additions and 0 deletions

View File

@@ -201,6 +201,24 @@ def parse_arguments() -> Namespace:
parser.add_argument("--extension-version", help="Version of the VSCode extension", required=False)
parser.add_argument("--use-git", help="Use Git for version control", action="store_true", required=False)
parser.add_argument("--access-token", help="Access token", required=False)
parser.add_argument(
"--enable-api-server",
action="store_true",
default=True,
help="Enable IPC server for external clients",
)
parser.add_argument(
"--local-api-server-host",
type=str,
default="localhost",
help="Host for the IPC server (default: localhost)",
)
parser.add_argument(
"--local-api-server-port",
type=int,
default=8222,
help="Port for the IPC server (default: 8222)",
)
return parser.parse_args()

View File

@@ -32,6 +32,7 @@ from core.llm.base import APIError
from core.log import get_logger
from core.state.state_manager import StateManager
from core.telemetry import telemetry
from core.ui.api_server import IPCServer
from core.ui.base import (
ProjectStage,
UIBase,
@@ -253,8 +254,21 @@ async def async_main(
sm = StateManager(db, ui)
if args.access_token:
sm.update_access_token(args.access_token)
# Start API server if enabled in config
api_server = None
if hasattr(args, "enable_api_server") and args.enable_api_server:
api_host = getattr(args, "local_api_server_host", "localhost")
api_port = getattr(args, "local_api_server_port", 8222) # Different from client port
api_server = IPCServer(api_host, api_port, sm)
server_started = await api_server.start()
if not server_started:
log.warning(f"Failed to start API server on {api_host}:{api_port}")
ui_started = await ui.start()
if not ui_started:
if api_server:
await api_server.stop()
return False
telemetry.start()
@@ -291,6 +305,8 @@ async def async_main(
raise
finally:
await cleanup(ui)
if api_server:
await api_server.stop()
return success

179
core/ui/api_server.py Normal file
View File

@@ -0,0 +1,179 @@
import asyncio
from typing import Awaitable, Callable, Dict, Optional
from pydantic import ValidationError
from core.log import get_logger
from core.state.state_manager import StateManager
from core.ui.ipc_client import MESSAGE_SIZE_LIMIT, Message, MessageType
log = get_logger(__name__)
class IPCServer:
"""
IPC server for handling requests from external clients.
"""
def __init__(self, host: str, port: int, state_manager: StateManager):
"""
Initialize the IPC server.
:param host: Host to bind to.
:param port: Port to listen on.
:param state_manager: State manager instance.
"""
self.host = host
self.port = port
self.state_manager = state_manager
self.server = None
self.handlers: Dict[MessageType, Callable[[Message, asyncio.StreamWriter], Awaitable[None]]] = {}
self._setup_handlers()
def _setup_handlers(self):
"""Set up message handlers."""
self.handlers[MessageType.EPICS_AND_TASKS] = self._handle_epics_and_tasks
# Add more handlers as needed
async def start(self) -> bool:
"""
Start the IPC server.
:return: True if server started successfully, False otherwise.
"""
try:
self.server = await asyncio.start_server(
self._handle_client,
self.host,
self.port,
limit=MESSAGE_SIZE_LIMIT,
)
log.info(f"IPC server started on {self.host}:{self.port}")
return True
except (OSError, ConnectionError) as err:
log.error(f"Failed to start IPC server: {err}")
return False
async def stop(self):
"""Stop the IPC server."""
if self.server:
self.server.close()
await self.server.wait_closed()
log.info(f"IPC server on {self.host}:{self.port} stopped")
self.server = None
async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
"""
Handle client connection.
:param reader: Stream reader.
:param writer: Stream writer.
"""
addr = writer.get_extra_info("peername")
log.debug(f"New connection from {addr}")
try:
while True:
# Read message length (4 bytes)
length_bytes = await reader.readexactly(4)
if not length_bytes:
break
# Parse message length
message_length = int.from_bytes(length_bytes, byteorder="big")
# Read message data
data = await reader.readexactly(message_length)
if not data:
break
# Parse message
try:
message = Message.from_bytes(data)
await self._process_message(message, writer)
except ValidationError as err:
log.error(f"Invalid message format: {err}")
await self._send_error(writer, "Invalid message format")
except ValueError as err:
log.error(f"Error decoding message: {err}")
await self._send_error(writer, "Error decoding message")
except asyncio.IncompleteReadError:
log.debug(f"Client {addr} disconnected")
except (ConnectionResetError, BrokenPipeError) as err:
log.debug(f"Connection to {addr} lost: {err}")
finally:
writer.close()
await writer.wait_closed()
log.debug(f"Connection to {addr} closed")
async def _process_message(self, message: Message, writer: asyncio.StreamWriter):
"""
Process incoming message.
:param message: Incoming message.
:param writer: Stream writer to send response.
"""
log.debug(f"Received message of type {message.type} with request ID {message.request_id}")
handler = self.handlers.get(message.type)
if handler:
await handler(message, writer)
else:
log.warning(f"No handler for message type {message.type}")
request_id = getattr(message, "request_id", None)
await self._send_error(writer, f"Unsupported message type: {message.type}", request_id)
async def _send_response(self, writer: asyncio.StreamWriter, message: Message):
"""
Send response to client.
:param writer: Stream writer.
:param message: Message to send.
"""
data = message.to_bytes()
try:
writer.write(len(data).to_bytes(4, byteorder="big"))
writer.write(data)
await writer.drain()
except (ConnectionResetError, BrokenPipeError) as err:
log.error(f"Failed to send response: {err}")
async def _send_error(self, writer: asyncio.StreamWriter, error_message: str, request_id: Optional[str] = None):
"""
Send error response to client.
:param writer: Stream writer.
:param error_message: Error message.
:param request_id: Optional request ID to include in the response.
"""
message = Message(type=MessageType.RESPONSE, content={"error": error_message}, request_id=request_id)
await self._send_response(writer, message)
async def _handle_epics_and_tasks(self, message: Message, writer: asyncio.StreamWriter):
"""
Handle request for epics and tasks.
:param message: Request message.
:param writer: Stream writer to send response.
"""
try:
# Get current state
current_state = self.state_manager.current_state
# Extract epics and tasks
epics = current_state.epics if current_state.epics else []
tasks = current_state.tasks if current_state.tasks else []
# Send response with the same request_id from the incoming message
response = Message(
type=MessageType.EPICS_AND_TASKS,
content={"epics": epics, "tasks": tasks},
request_id=message.request_id, # Include the request_id from the incoming message
)
log.debug(f"Sending epics and tasks response with request_id: {message.request_id}")
await self._send_response(writer, response)
except Exception as err:
log.error(f"Error handling epics and tasks request: {err}", exc_info=True)
await self._send_error(writer, f"Internal server error: {str(err)}", message.request_id)

View File

@@ -70,6 +70,7 @@ class Message(BaseModel):
* `extra_info`: Additional information (eg. "This is a hint"), optional
* `placeholder`: Placeholder for user input, optional
* `access_token`: Access token for user input, optional
* `request_id`: Unique identifier for request-response matching, optional
"""
type: MessageType
@@ -80,6 +81,7 @@ class Message(BaseModel):
content: Union[str, dict, None] = None
placeholder: Optional[str] = None
accessToken: Optional[str] = None
request_id: Optional[str] = None
def to_bytes(self) -> bytes:
"""

View File

@@ -44,6 +44,9 @@ def test_parse_arguments(mock_ArgumentParser):
"--database",
"--local-ipc-port",
"--local-ipc-host",
"--enable-api-server",
"--local-api-server-host",
"--local-api-server-port",
"--version",
"--list",
"--list-json",

View File

@@ -105,6 +105,7 @@ async def test_send_message():
"extra_info": "test",
"placeholder": None,
"accessToken": None,
"request_id": None,
},
{
"type": "exit",
@@ -115,6 +116,7 @@ async def test_send_message():
"extra_info": None,
"placeholder": None,
"accessToken": None,
"request_id": None,
},
]
@@ -145,6 +147,7 @@ async def test_stream():
"extra_info": None,
"placeholder": None,
"accessToken": None,
"request_id": None,
},
{
"type": "stream",
@@ -155,6 +158,7 @@ async def test_stream():
"extra_info": None,
"placeholder": None,
"accessToken": None,
"request_id": None,
},
{
"type": "exit",
@@ -165,6 +169,7 @@ async def test_stream():
"extra_info": None,
"placeholder": None,
"accessToken": None,
"request_id": None,
},
]