mirror of
https://github.com/Pythagora-io/gpt-pilot.git
synced 2026-01-09 21:27:53 -05:00
initial setup of API in Core
This commit is contained in:
@@ -201,6 +201,24 @@ def parse_arguments() -> Namespace:
|
||||
parser.add_argument("--extension-version", help="Version of the VSCode extension", required=False)
|
||||
parser.add_argument("--use-git", help="Use Git for version control", action="store_true", required=False)
|
||||
parser.add_argument("--access-token", help="Access token", required=False)
|
||||
parser.add_argument(
|
||||
"--enable-api-server",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Enable IPC server for external clients",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-api-server-host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Host for the IPC server (default: localhost)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-api-server-port",
|
||||
type=int,
|
||||
default=8222,
|
||||
help="Port for the IPC server (default: 8222)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from core.llm.base import APIError
|
||||
from core.log import get_logger
|
||||
from core.state.state_manager import StateManager
|
||||
from core.telemetry import telemetry
|
||||
from core.ui.api_server import IPCServer
|
||||
from core.ui.base import (
|
||||
ProjectStage,
|
||||
UIBase,
|
||||
@@ -253,8 +254,21 @@ async def async_main(
|
||||
sm = StateManager(db, ui)
|
||||
if args.access_token:
|
||||
sm.update_access_token(args.access_token)
|
||||
|
||||
# Start API server if enabled in config
|
||||
api_server = None
|
||||
if hasattr(args, "enable_api_server") and args.enable_api_server:
|
||||
api_host = getattr(args, "local_api_server_host", "localhost")
|
||||
api_port = getattr(args, "local_api_server_port", 8222) # Different from client port
|
||||
api_server = IPCServer(api_host, api_port, sm)
|
||||
server_started = await api_server.start()
|
||||
if not server_started:
|
||||
log.warning(f"Failed to start API server on {api_host}:{api_port}")
|
||||
|
||||
ui_started = await ui.start()
|
||||
if not ui_started:
|
||||
if api_server:
|
||||
await api_server.stop()
|
||||
return False
|
||||
|
||||
telemetry.start()
|
||||
@@ -291,6 +305,8 @@ async def async_main(
|
||||
raise
|
||||
finally:
|
||||
await cleanup(ui)
|
||||
if api_server:
|
||||
await api_server.stop()
|
||||
|
||||
return success
|
||||
|
||||
|
||||
179
core/ui/api_server.py
Normal file
179
core/ui/api_server.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import asyncio
|
||||
from typing import Awaitable, Callable, Dict, Optional
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.log import get_logger
|
||||
from core.state.state_manager import StateManager
|
||||
from core.ui.ipc_client import MESSAGE_SIZE_LIMIT, Message, MessageType
|
||||
|
||||
log = get_logger(__name__)
|
||||
|
||||
|
||||
class IPCServer:
|
||||
"""
|
||||
IPC server for handling requests from external clients.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str, port: int, state_manager: StateManager):
|
||||
"""
|
||||
Initialize the IPC server.
|
||||
|
||||
:param host: Host to bind to.
|
||||
:param port: Port to listen on.
|
||||
:param state_manager: State manager instance.
|
||||
"""
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.state_manager = state_manager
|
||||
self.server = None
|
||||
self.handlers: Dict[MessageType, Callable[[Message, asyncio.StreamWriter], Awaitable[None]]] = {}
|
||||
self._setup_handlers()
|
||||
|
||||
def _setup_handlers(self):
|
||||
"""Set up message handlers."""
|
||||
self.handlers[MessageType.EPICS_AND_TASKS] = self._handle_epics_and_tasks
|
||||
# Add more handlers as needed
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
Start the IPC server.
|
||||
|
||||
:return: True if server started successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
self.server = await asyncio.start_server(
|
||||
self._handle_client,
|
||||
self.host,
|
||||
self.port,
|
||||
limit=MESSAGE_SIZE_LIMIT,
|
||||
)
|
||||
log.info(f"IPC server started on {self.host}:{self.port}")
|
||||
return True
|
||||
except (OSError, ConnectionError) as err:
|
||||
log.error(f"Failed to start IPC server: {err}")
|
||||
return False
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the IPC server."""
|
||||
if self.server:
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
log.info(f"IPC server on {self.host}:{self.port} stopped")
|
||||
self.server = None
|
||||
|
||||
async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
"""
|
||||
Handle client connection.
|
||||
|
||||
:param reader: Stream reader.
|
||||
:param writer: Stream writer.
|
||||
"""
|
||||
addr = writer.get_extra_info("peername")
|
||||
log.debug(f"New connection from {addr}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Read message length (4 bytes)
|
||||
length_bytes = await reader.readexactly(4)
|
||||
if not length_bytes:
|
||||
break
|
||||
|
||||
# Parse message length
|
||||
message_length = int.from_bytes(length_bytes, byteorder="big")
|
||||
|
||||
# Read message data
|
||||
data = await reader.readexactly(message_length)
|
||||
if not data:
|
||||
break
|
||||
|
||||
# Parse message
|
||||
try:
|
||||
message = Message.from_bytes(data)
|
||||
await self._process_message(message, writer)
|
||||
except ValidationError as err:
|
||||
log.error(f"Invalid message format: {err}")
|
||||
await self._send_error(writer, "Invalid message format")
|
||||
except ValueError as err:
|
||||
log.error(f"Error decoding message: {err}")
|
||||
await self._send_error(writer, "Error decoding message")
|
||||
|
||||
except asyncio.IncompleteReadError:
|
||||
log.debug(f"Client {addr} disconnected")
|
||||
except (ConnectionResetError, BrokenPipeError) as err:
|
||||
log.debug(f"Connection to {addr} lost: {err}")
|
||||
finally:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
log.debug(f"Connection to {addr} closed")
|
||||
|
||||
async def _process_message(self, message: Message, writer: asyncio.StreamWriter):
|
||||
"""
|
||||
Process incoming message.
|
||||
|
||||
:param message: Incoming message.
|
||||
:param writer: Stream writer to send response.
|
||||
"""
|
||||
log.debug(f"Received message of type {message.type} with request ID {message.request_id}")
|
||||
|
||||
handler = self.handlers.get(message.type)
|
||||
if handler:
|
||||
await handler(message, writer)
|
||||
else:
|
||||
log.warning(f"No handler for message type {message.type}")
|
||||
request_id = getattr(message, "request_id", None)
|
||||
await self._send_error(writer, f"Unsupported message type: {message.type}", request_id)
|
||||
|
||||
async def _send_response(self, writer: asyncio.StreamWriter, message: Message):
|
||||
"""
|
||||
Send response to client.
|
||||
|
||||
:param writer: Stream writer.
|
||||
:param message: Message to send.
|
||||
"""
|
||||
data = message.to_bytes()
|
||||
try:
|
||||
writer.write(len(data).to_bytes(4, byteorder="big"))
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
except (ConnectionResetError, BrokenPipeError) as err:
|
||||
log.error(f"Failed to send response: {err}")
|
||||
|
||||
async def _send_error(self, writer: asyncio.StreamWriter, error_message: str, request_id: Optional[str] = None):
|
||||
"""
|
||||
Send error response to client.
|
||||
|
||||
:param writer: Stream writer.
|
||||
:param error_message: Error message.
|
||||
:param request_id: Optional request ID to include in the response.
|
||||
"""
|
||||
message = Message(type=MessageType.RESPONSE, content={"error": error_message}, request_id=request_id)
|
||||
await self._send_response(writer, message)
|
||||
|
||||
async def _handle_epics_and_tasks(self, message: Message, writer: asyncio.StreamWriter):
|
||||
"""
|
||||
Handle request for epics and tasks.
|
||||
|
||||
:param message: Request message.
|
||||
:param writer: Stream writer to send response.
|
||||
"""
|
||||
try:
|
||||
# Get current state
|
||||
current_state = self.state_manager.current_state
|
||||
|
||||
# Extract epics and tasks
|
||||
epics = current_state.epics if current_state.epics else []
|
||||
tasks = current_state.tasks if current_state.tasks else []
|
||||
|
||||
# Send response with the same request_id from the incoming message
|
||||
response = Message(
|
||||
type=MessageType.EPICS_AND_TASKS,
|
||||
content={"epics": epics, "tasks": tasks},
|
||||
request_id=message.request_id, # Include the request_id from the incoming message
|
||||
)
|
||||
log.debug(f"Sending epics and tasks response with request_id: {message.request_id}")
|
||||
await self._send_response(writer, response)
|
||||
|
||||
except Exception as err:
|
||||
log.error(f"Error handling epics and tasks request: {err}", exc_info=True)
|
||||
await self._send_error(writer, f"Internal server error: {str(err)}", message.request_id)
|
||||
@@ -70,6 +70,7 @@ class Message(BaseModel):
|
||||
* `extra_info`: Additional information (eg. "This is a hint"), optional
|
||||
* `placeholder`: Placeholder for user input, optional
|
||||
* `access_token`: Access token for user input, optional
|
||||
* `request_id`: Unique identifier for request-response matching, optional
|
||||
"""
|
||||
|
||||
type: MessageType
|
||||
@@ -80,6 +81,7 @@ class Message(BaseModel):
|
||||
content: Union[str, dict, None] = None
|
||||
placeholder: Optional[str] = None
|
||||
accessToken: Optional[str] = None
|
||||
request_id: Optional[str] = None
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
"""
|
||||
|
||||
@@ -44,6 +44,9 @@ def test_parse_arguments(mock_ArgumentParser):
|
||||
"--database",
|
||||
"--local-ipc-port",
|
||||
"--local-ipc-host",
|
||||
"--enable-api-server",
|
||||
"--local-api-server-host",
|
||||
"--local-api-server-port",
|
||||
"--version",
|
||||
"--list",
|
||||
"--list-json",
|
||||
|
||||
@@ -105,6 +105,7 @@ async def test_send_message():
|
||||
"extra_info": "test",
|
||||
"placeholder": None,
|
||||
"accessToken": None,
|
||||
"request_id": None,
|
||||
},
|
||||
{
|
||||
"type": "exit",
|
||||
@@ -115,6 +116,7 @@ async def test_send_message():
|
||||
"extra_info": None,
|
||||
"placeholder": None,
|
||||
"accessToken": None,
|
||||
"request_id": None,
|
||||
},
|
||||
]
|
||||
|
||||
@@ -145,6 +147,7 @@ async def test_stream():
|
||||
"extra_info": None,
|
||||
"placeholder": None,
|
||||
"accessToken": None,
|
||||
"request_id": None,
|
||||
},
|
||||
{
|
||||
"type": "stream",
|
||||
@@ -155,6 +158,7 @@ async def test_stream():
|
||||
"extra_info": None,
|
||||
"placeholder": None,
|
||||
"accessToken": None,
|
||||
"request_id": None,
|
||||
},
|
||||
{
|
||||
"type": "exit",
|
||||
@@ -165,6 +169,7 @@ async def test_stream():
|
||||
"extra_info": None,
|
||||
"placeholder": None,
|
||||
"accessToken": None,
|
||||
"request_id": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user