mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-10 07:18:10 -05:00
Rename OpenDevin to OpenHands (#3472)
* Replace OpenDevin with OpenHands * Update CONTRIBUTING.md * Update README.md * Update README.md * update poetry lock; move opendevin folder to openhands * fix env var * revert image references in docs * revert permissions * revert permissions --------- Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
This commit is contained in:
73
openhands/security/README.md
Normal file
73
openhands/security/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Security
|
||||
|
||||
Given the impressive capabilities of OpenHands and similar coding agents, ensuring robust security measures is essential to prevent unintended actions or security breaches. The SecurityAnalyzer framework provides a structured approach to monitor and analyze agent actions for potential security risks.
|
||||
|
||||
To enable this feature:
|
||||
* From the web interface
|
||||
* Open Configuration (by clicking the gear icon in the bottom right)
|
||||
* Select a Security Analyzer from the dropdown
|
||||
* Save settings
|
||||
* (to disable) repeat the same steps, but click the X in the Security Analyzer dropdown
|
||||
* From config.toml
|
||||
```toml
|
||||
[security]
|
||||
# Enable confirmation mode
|
||||
confirmation_mode = true
|
||||
# The security analyzer to use
|
||||
security_analyzer = "your-security-analyzer"
|
||||
```
|
||||
(to disable) remove the lines from config.toml
|
||||
|
||||
## SecurityAnalyzer Base Class
|
||||
|
||||
The `SecurityAnalyzer` class (analyzer.py) is an abstract base class designed to listen to an event stream and analyze actions for security risks and eventually act before the action is executed. Below is a detailed explanation of its components and methods:
|
||||
|
||||
### Initialization
|
||||
|
||||
- **event_stream**: An instance of `EventStream` that the analyzer will listen to for events.
|
||||
|
||||
### Event Handling
|
||||
|
||||
- **on_event(event: Event)**: Handles incoming events. If the event is an `Action`, it evaluates its security risk and acts upon it.
|
||||
|
||||
### Abstract Methods
|
||||
|
||||
- **handle_api_request(request: Request)**: Abstract method to handle API requests.
|
||||
- **log_event(event: Event)**: Logs events.
|
||||
- **act(event: Event)**: Defines actions to take based on the analyzed event.
|
||||
- **security_risk(event: Action)**: Evaluates the security risk of an action and returns the risk level.
|
||||
- **close()**: Cleanups resources used by the security analyzer.
|
||||
|
||||
In conclusion, a concrete security analyzer should evaluate the risk of each event and act accordingly (e.g. auto-confirm, send Slack message, etc).
|
||||
|
||||
For customization and decoupling from the OpenHands core logic, the security analyzer can define its own API endpoints that can then be accessed from the frontend. These API endpoints need to be secured (do not allow more capabilities than the core logic
|
||||
provides).
|
||||
|
||||
## How to implement your own Security Analyzer
|
||||
|
||||
1. Create a submodule in [security](/openhands/security/) with your analyzer's desired name
|
||||
* Have your main class inherit from [SecurityAnalyzer](/openhands/security/analyzer.py)
|
||||
* Optional: define API endpoints for `/api/security/{path:path}` to manage settings,
|
||||
2. Add your analyzer class to the [options](/openhands/security/options.py) to have it be visible from the frontend combobox
|
||||
3. Optional: implement your modal frontend (for when you click on the lock) in [security](/frontend/src/components/modals/security/) and add your component to [Security.tsx](/frontend/src/components/modals/security/Security.tsx)
|
||||
|
||||
## Implemented Security Analyzers
|
||||
|
||||
### Invariant
|
||||
|
||||
It uses the [Invariant Analyzer](https://github.com/invariantlabs-ai/invariant) to analyze traces and detect potential issues with OpenHands's workflow. It uses confirmation mode to ask for user confirmation on potentially risky actions.
|
||||
|
||||
This allows the agent to run autonomously without fear that it will inadvertently compromise security or perform unintended actions that could be harmful.
|
||||
|
||||
Features:
|
||||
|
||||
* Detects:
|
||||
* potential secret leaks by the agent
|
||||
* security issues in Python code
|
||||
* malicious bash commands
|
||||
* Logs:
|
||||
* actions and their associated risk
|
||||
* OpenHands traces in JSON format
|
||||
* Run-time settings:
|
||||
* the [invariant policy](https://github.com/invariantlabs-ai/invariant?tab=readme-ov-file#policy-language)
|
||||
* acceptable risk threshold
|
||||
7
openhands/security/__init__.py
Normal file
7
openhands/security/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .analyzer import SecurityAnalyzer
|
||||
from .invariant.analyzer import InvariantAnalyzer
|
||||
|
||||
__all__ = [
|
||||
'SecurityAnalyzer',
|
||||
'InvariantAnalyzer',
|
||||
]
|
||||
60
openhands/security/analyzer.py
Normal file
60
openhands/security/analyzer.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
|
||||
|
||||
class SecurityAnalyzer:
|
||||
"""Security analyzer that receives all events and analyzes agent actions for security risks."""
|
||||
|
||||
def __init__(self, event_stream: EventStream):
|
||||
"""Initializes a new instance of the SecurityAnalyzer class.
|
||||
|
||||
Args:
|
||||
event_stream: The event stream to listen for events.
|
||||
"""
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.SECURITY_ANALYZER, self.on_event
|
||||
)
|
||||
|
||||
async def on_event(self, event: Event) -> None:
|
||||
"""Handles the incoming event, and when Action is received, analyzes it for security risks."""
|
||||
logger.info(f'SecurityAnalyzer received event: {event}')
|
||||
await self.log_event(event)
|
||||
if not isinstance(event, Action):
|
||||
return
|
||||
|
||||
try:
|
||||
event.security_risk = await self.security_risk(event) # type: ignore [attr-defined]
|
||||
await self.act(event)
|
||||
except Exception as e:
|
||||
logger.error(f'Error occurred while analyzing the event: {e}')
|
||||
|
||||
async def handle_api_request(self, request: Request) -> Any:
|
||||
"""Handles the incoming API request."""
|
||||
raise NotImplementedError(
|
||||
'Need to implement handle_api_request method in SecurityAnalyzer subclass'
|
||||
)
|
||||
|
||||
async def log_event(self, event: Event) -> None:
|
||||
"""Logs the incoming event."""
|
||||
pass
|
||||
|
||||
async def act(self, event: Event) -> None:
|
||||
"""Performs an action based on the analyzed event."""
|
||||
pass
|
||||
|
||||
async def security_risk(self, event: Action) -> ActionSecurityRisk:
|
||||
"""Evaluates the Action for security risks and returns the risk level."""
|
||||
raise NotImplementedError(
|
||||
'Need to implement security_risk method in SecurityAnalyzer subclass'
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Cleanup resources allocated by the SecurityAnalyzer."""
|
||||
pass
|
||||
5
openhands/security/invariant/__init__.py
Normal file
5
openhands/security/invariant/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from .analyzer import InvariantAnalyzer
|
||||
|
||||
__all__ = [
|
||||
'InvariantAnalyzer',
|
||||
]
|
||||
206
openhands/security/invariant/analyzer.py
Normal file
206
openhands/security/invariant/analyzer.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import docker
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.action import (
|
||||
Action,
|
||||
ActionSecurityRisk,
|
||||
)
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation import Observation
|
||||
from openhands.events.serialization.action import action_from_dict
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.runtime.utils import find_available_tcp_port
|
||||
from openhands.security.analyzer import SecurityAnalyzer
|
||||
from openhands.security.invariant.client import InvariantClient
|
||||
from openhands.security.invariant.parser import TraceElement, parse_element
|
||||
|
||||
|
||||
class InvariantAnalyzer(SecurityAnalyzer):
|
||||
"""Security analyzer based on Invariant."""
|
||||
|
||||
trace: list[TraceElement]
|
||||
input: list[dict]
|
||||
container_name: str = 'openhands-invariant-server'
|
||||
image_name: str = 'ghcr.io/invariantlabs-ai/server:openhands'
|
||||
api_host: str = 'http://localhost'
|
||||
timeout: int = 180
|
||||
settings: dict = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_stream: EventStream,
|
||||
policy: str | None = None,
|
||||
sid: str | None = None,
|
||||
):
|
||||
"""Initializes a new instance of the InvariantAnalzyer class."""
|
||||
super().__init__(event_stream)
|
||||
self.trace = []
|
||||
self.input = []
|
||||
self.settings = {}
|
||||
if sid is None:
|
||||
self.sid = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
self.docker_client = docker.from_env()
|
||||
except Exception as ex:
|
||||
logger.exception(
|
||||
'Error creating Invariant Security Analyzer container. Please check that Docker is running or disable the Security Analyzer in settings.',
|
||||
exc_info=False,
|
||||
)
|
||||
raise ex
|
||||
running_containers = self.docker_client.containers.list(
|
||||
filters={'name': self.container_name}
|
||||
)
|
||||
if not running_containers:
|
||||
all_containers = self.docker_client.containers.list(
|
||||
all=True, filters={'name': self.container_name}
|
||||
)
|
||||
if all_containers:
|
||||
self.container = all_containers[0]
|
||||
all_containers[0].start()
|
||||
else:
|
||||
self.api_port = find_available_tcp_port()
|
||||
self.container = self.docker_client.containers.run(
|
||||
self.image_name,
|
||||
name=self.container_name,
|
||||
platform='linux/amd64',
|
||||
ports={'8000/tcp': self.api_port},
|
||||
detach=True,
|
||||
)
|
||||
else:
|
||||
self.container = running_containers[0]
|
||||
|
||||
elapsed = 0
|
||||
while self.container.status != 'running':
|
||||
self.container = self.docker_client.containers.get(self.container_name)
|
||||
elapsed += 1
|
||||
logger.info(
|
||||
f'waiting for container to start: {elapsed}, container status: {self.container.status}'
|
||||
)
|
||||
if elapsed > self.timeout:
|
||||
break
|
||||
|
||||
self.api_port = int(
|
||||
self.container.attrs['NetworkSettings']['Ports']['8000/tcp'][0]['HostPort']
|
||||
)
|
||||
|
||||
self.api_server = f'{self.api_host}:{self.api_port}'
|
||||
self.client = InvariantClient(self.api_server, self.sid)
|
||||
if policy is None:
|
||||
policy, _ = self.client.Policy.get_template()
|
||||
if policy is None:
|
||||
policy = ''
|
||||
self.monitor = self.client.Monitor.from_string(policy)
|
||||
|
||||
async def close(self):
|
||||
self.container.stop()
|
||||
|
||||
async def log_event(self, event: Event) -> None:
|
||||
if isinstance(event, Observation):
|
||||
element = parse_element(self.trace, event)
|
||||
self.trace.extend(element)
|
||||
self.input.extend([e.model_dump(exclude_none=True) for e in element]) # type: ignore [call-overload]
|
||||
else:
|
||||
logger.info('Invariant skipping element: event')
|
||||
|
||||
def get_risk(self, results: list[str]) -> ActionSecurityRisk:
|
||||
mapping = {
|
||||
'high': ActionSecurityRisk.HIGH,
|
||||
'medium': ActionSecurityRisk.MEDIUM,
|
||||
'low': ActionSecurityRisk.LOW,
|
||||
}
|
||||
regex = r'(?<=risk=)\w+'
|
||||
risks = []
|
||||
for result in results:
|
||||
m = re.search(regex, result)
|
||||
if m and m.group() in mapping:
|
||||
risks.append(mapping[m.group()])
|
||||
|
||||
if risks:
|
||||
return max(risks)
|
||||
|
||||
return ActionSecurityRisk.LOW
|
||||
|
||||
async def act(self, event: Event) -> None:
|
||||
if await self.should_confirm(event):
|
||||
await self.confirm(event)
|
||||
|
||||
async def should_confirm(self, event: Event) -> bool:
|
||||
risk = event.security_risk # type: ignore [attr-defined]
|
||||
return (
|
||||
risk is not None
|
||||
and risk < self.settings.get('RISK_SEVERITY', ActionSecurityRisk.MEDIUM)
|
||||
and hasattr(event, 'is_confirmed')
|
||||
and event.is_confirmed == 'awaiting_confirmation'
|
||||
)
|
||||
|
||||
async def confirm(self, event: Event) -> None:
|
||||
new_event = action_from_dict(
|
||||
{'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}}
|
||||
)
|
||||
if event.source:
|
||||
self.event_stream.add_event(new_event, event.source)
|
||||
else:
|
||||
self.event_stream.add_event(new_event, EventSource.AGENT)
|
||||
|
||||
async def security_risk(self, event: Action) -> ActionSecurityRisk:
|
||||
logger.info('Calling security_risk on InvariantAnalyzer')
|
||||
new_elements = parse_element(self.trace, event)
|
||||
input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload]
|
||||
self.trace.extend(new_elements)
|
||||
result, err = self.monitor.check(self.input, input)
|
||||
self.input.extend(input)
|
||||
risk = ActionSecurityRisk.UNKNOWN
|
||||
if err:
|
||||
logger.warning(f'Error checking policy: {err}')
|
||||
return risk
|
||||
|
||||
risk = self.get_risk(result)
|
||||
|
||||
return risk
|
||||
|
||||
### Handle API requests
|
||||
async def handle_api_request(self, request: Request) -> Any:
|
||||
path_parts = request.url.path.strip('/').split('/')
|
||||
endpoint = path_parts[-1] # Get the last part of the path
|
||||
|
||||
if request.method == 'GET':
|
||||
if endpoint == 'export-trace':
|
||||
return await self.export_trace(request)
|
||||
elif endpoint == 'policy':
|
||||
return await self.get_policy(request)
|
||||
elif endpoint == 'settings':
|
||||
return await self.get_settings(request)
|
||||
elif request.method == 'POST':
|
||||
if endpoint == 'policy':
|
||||
return await self.update_policy(request)
|
||||
elif endpoint == 'settings':
|
||||
return await self.update_settings(request)
|
||||
raise HTTPException(status_code=405, detail='Method Not Allowed')
|
||||
|
||||
async def export_trace(self, request: Request) -> Any:
|
||||
return JSONResponse(content=self.input)
|
||||
|
||||
async def get_policy(self, request: Request) -> Any:
|
||||
return JSONResponse(content={'policy': self.monitor.policy})
|
||||
|
||||
async def update_policy(self, request: Request) -> Any:
|
||||
data = await request.json()
|
||||
policy = data.get('policy')
|
||||
new_monitor = self.client.Monitor.from_string(policy)
|
||||
self.monitor = new_monitor
|
||||
return JSONResponse(content={'policy': policy})
|
||||
|
||||
async def get_settings(self, request: Request) -> Any:
|
||||
return JSONResponse(content=self.settings)
|
||||
|
||||
async def update_settings(self, request: Request) -> Any:
|
||||
settings = await request.json()
|
||||
self.settings = settings
|
||||
return JSONResponse(content=self.settings)
|
||||
137
openhands/security/invariant/client.py
Normal file
137
openhands/security/invariant/client.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import time
|
||||
from typing import Any, Union
|
||||
|
||||
import requests
|
||||
from requests.exceptions import ConnectionError, HTTPError, Timeout
|
||||
|
||||
|
||||
class InvariantClient:
|
||||
timeout: int = 120
|
||||
|
||||
def __init__(self, server_url: str, session_id: str | None = None):
|
||||
self.server = server_url
|
||||
self.session_id, err = self._create_session(session_id)
|
||||
if err:
|
||||
raise RuntimeError(f'Failed to create session: {err}')
|
||||
self.Policy = self._Policy(self)
|
||||
self.Monitor = self._Monitor(self)
|
||||
|
||||
def _create_session(
|
||||
self, session_id: str | None = None
|
||||
) -> tuple[str | None, Exception | None]:
|
||||
elapsed = 0
|
||||
while elapsed < self.timeout:
|
||||
try:
|
||||
if session_id:
|
||||
response = requests.get(
|
||||
f'{self.server}/session/new?session_id={session_id}', timeout=60
|
||||
)
|
||||
else:
|
||||
response = requests.get(f'{self.server}/session/new', timeout=60)
|
||||
response.raise_for_status()
|
||||
return response.json().get('id'), None
|
||||
except (ConnectionError, Timeout):
|
||||
elapsed += 1
|
||||
time.sleep(1)
|
||||
except HTTPError as http_err:
|
||||
return None, http_err
|
||||
except Exception as err:
|
||||
return None, err
|
||||
return None, ConnectionError('Connection timed out')
|
||||
|
||||
def close_session(self) -> Union[None, Exception]:
|
||||
try:
|
||||
response = requests.delete(
|
||||
f'{self.server}/session/?session_id={self.session_id}', timeout=60
|
||||
)
|
||||
response.raise_for_status()
|
||||
except (ConnectionError, Timeout, HTTPError) as err:
|
||||
return err
|
||||
return None
|
||||
|
||||
class _Policy:
|
||||
def __init__(self, invariant):
|
||||
self.server = invariant.server
|
||||
self.session_id = invariant.session_id
|
||||
|
||||
def _create_policy(self, rule: str) -> tuple[str | None, Exception | None]:
|
||||
try:
|
||||
response = requests.post(
|
||||
f'{self.server}/policy/new?session_id={self.session_id}',
|
||||
json={'rule': rule},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json().get('policy_id'), None
|
||||
except (ConnectionError, Timeout, HTTPError) as err:
|
||||
return None, err
|
||||
|
||||
def get_template(self) -> tuple[str | None, Exception | None]:
|
||||
try:
|
||||
response = requests.get(
|
||||
f'{self.server}/policy/template',
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json(), None
|
||||
except (ConnectionError, Timeout, HTTPError) as err:
|
||||
return None, err
|
||||
|
||||
def from_string(self, rule: str):
|
||||
policy_id, err = self._create_policy(rule)
|
||||
if err:
|
||||
raise err
|
||||
self.policy_id = policy_id
|
||||
return self
|
||||
|
||||
def analyze(self, trace: list[dict]) -> Union[Any, Exception]:
|
||||
try:
|
||||
response = requests.post(
|
||||
f'{self.server}/policy/{self.policy_id}/analyze?session_id={self.session_id}',
|
||||
json={'trace': trace},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json(), None
|
||||
except (ConnectionError, Timeout, HTTPError) as err:
|
||||
return None, err
|
||||
|
||||
class _Monitor:
|
||||
def __init__(self, invariant):
|
||||
self.server = invariant.server
|
||||
self.session_id = invariant.session_id
|
||||
self.policy = ''
|
||||
|
||||
def _create_monitor(self, rule: str) -> tuple[str | None, Exception | None]:
|
||||
try:
|
||||
response = requests.post(
|
||||
f'{self.server}/monitor/new?session_id={self.session_id}',
|
||||
json={'rule': rule},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json().get('monitor_id'), None
|
||||
except (ConnectionError, Timeout, HTTPError) as err:
|
||||
return None, err
|
||||
|
||||
def from_string(self, rule: str):
|
||||
monitor_id, err = self._create_monitor(rule)
|
||||
if err:
|
||||
raise err
|
||||
self.monitor_id = monitor_id
|
||||
self.policy = rule
|
||||
return self
|
||||
|
||||
def check(
|
||||
self, past_events: list[dict], pending_events: list[dict]
|
||||
) -> Union[Any, Exception]:
|
||||
try:
|
||||
response = requests.post(
|
||||
f'{self.server}/monitor/{self.monitor_id}/check?session_id={self.session_id}',
|
||||
json={'past_events': past_events, 'pending_events': pending_events},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json(), None
|
||||
except (ConnectionError, Timeout, HTTPError) as err:
|
||||
return None, err
|
||||
45
openhands/security/invariant/nodes.py
Normal file
45
openhands/security/invariant/nodes.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLM:
|
||||
vendor: str
|
||||
model: str
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
metadata: dict | None = Field(
|
||||
default_factory=dict, description='Metadata associated with the event'
|
||||
)
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: str
|
||||
arguments: dict
|
||||
|
||||
|
||||
class ToolCall(Event):
|
||||
id: str
|
||||
type: str
|
||||
function: Function
|
||||
|
||||
|
||||
class Message(Event):
|
||||
role: str
|
||||
content: str | None
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
|
||||
def __rich_repr__(self):
|
||||
# Print on separate line
|
||||
yield 'role', self.role
|
||||
yield 'content', self.content
|
||||
yield 'tool_calls', self.tool_calls
|
||||
|
||||
|
||||
class ToolOutput(Event):
|
||||
role: str
|
||||
content: str
|
||||
tool_call_id: str | None = None
|
||||
|
||||
_tool_call: ToolCall | None = None
|
||||
103
openhands/security/invariant/parser.py
Normal file
103
openhands/security/invariant/parser.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ChangeAgentStateAction,
|
||||
MessageAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput
|
||||
|
||||
TraceElement = Union[Message, ToolCall, ToolOutput, Function]
|
||||
|
||||
|
||||
def get_next_id(trace: list[TraceElement]) -> str:
|
||||
used_ids = [el.id for el in trace if type(el) == ToolCall]
|
||||
for i in range(1, len(used_ids) + 2):
|
||||
if str(i) not in used_ids:
|
||||
return str(i)
|
||||
return '1'
|
||||
|
||||
|
||||
def get_last_id(
|
||||
trace: list[TraceElement],
|
||||
) -> str | None:
|
||||
for el in reversed(trace):
|
||||
if type(el) == ToolCall:
|
||||
return el.id
|
||||
return None
|
||||
|
||||
|
||||
def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
|
||||
next_id = get_next_id(trace)
|
||||
inv_trace = [] # type: list[TraceElement]
|
||||
if type(action) == MessageAction:
|
||||
if action.source == EventSource.USER:
|
||||
inv_trace.append(Message(role='user', content=action.content))
|
||||
else:
|
||||
inv_trace.append(Message(role='assistant', content=action.content))
|
||||
elif type(action) in [NullAction, ChangeAgentStateAction]:
|
||||
pass
|
||||
elif hasattr(action, 'action') and action.action is not None:
|
||||
event_dict = event_to_dict(action)
|
||||
args = event_dict.get('args', {})
|
||||
thought = args.pop('thought', None)
|
||||
function = Function(name=action.action, arguments=args)
|
||||
if thought is not None:
|
||||
inv_trace.append(Message(role='assistant', content=thought))
|
||||
inv_trace.append(ToolCall(id=next_id, type='function', function=function))
|
||||
else:
|
||||
logger.error(f'Unknown action type: {type(action)}')
|
||||
return inv_trace
|
||||
|
||||
|
||||
def parse_observation(
|
||||
trace: list[TraceElement], obs: Observation
|
||||
) -> list[TraceElement]:
|
||||
last_id = get_last_id(trace)
|
||||
if type(obs) in [NullObservation, AgentStateChangedObservation]:
|
||||
return []
|
||||
elif hasattr(obs, 'content') and obs.content is not None:
|
||||
return [ToolOutput(role='tool', content=obs.content, tool_call_id=last_id)]
|
||||
else:
|
||||
logger.error(f'Unknown observation type: {type(obs)}')
|
||||
return []
|
||||
|
||||
|
||||
def parse_element(
|
||||
trace: list[TraceElement], element: Action | Observation
|
||||
) -> list[TraceElement]:
|
||||
if isinstance(element, Action):
|
||||
return parse_action(trace, element)
|
||||
return parse_observation(trace, element)
|
||||
|
||||
|
||||
def parse_trace(trace: list[tuple[Action, Observation]]):
|
||||
inv_trace = [] # type: list[TraceElement]
|
||||
for action, obs in trace:
|
||||
inv_trace.extend(parse_action(inv_trace, action))
|
||||
inv_trace.extend(parse_observation(inv_trace, obs))
|
||||
return inv_trace
|
||||
|
||||
|
||||
class InvariantState(BaseModel):
|
||||
trace: list[TraceElement] = Field(default_factory=list)
|
||||
|
||||
def add_action(self, action: Action):
|
||||
self.trace.extend(parse_action(self.trace, action))
|
||||
|
||||
def add_observation(self, obs: Observation):
|
||||
self.trace.extend(parse_observation(self.trace, obs))
|
||||
|
||||
def concatenate(self, other: 'InvariantState'):
|
||||
self.trace.extend(other.trace)
|
||||
19
openhands/security/invariant/policies.py
Normal file
19
openhands/security/invariant/policies.py
Normal file
@@ -0,0 +1,19 @@
|
||||
DEFAULT_INVARIANT_POLICY = """from invariant.detectors import semgrep, secrets, CodeIssue
|
||||
|
||||
raise "Disallow secrets in bash commands [risk=medium]" if:
|
||||
(call: ToolCall)
|
||||
call is tool:cmd_run
|
||||
any(secrets(call.function.arguments.command))
|
||||
|
||||
raise "Vulnerability in python code [risk=medium]" if:
|
||||
(call: ToolCall)
|
||||
call is tool:ipython_run_cell
|
||||
semgrep_res := semgrep(call.function.arguments.code, lang="python")
|
||||
any(semgrep_res)
|
||||
|
||||
raise "Vulnerability in bash command [risk=medium]" if:
|
||||
(call: ToolCall)
|
||||
call is tool:cmd_run
|
||||
semgrep_res := semgrep(call.function.arguments.command, lang="bash")
|
||||
any(semgrep_res)
|
||||
"""
|
||||
5
openhands/security/options.py
Normal file
5
openhands/security/options.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from openhands.security.invariant.analyzer import InvariantAnalyzer
|
||||
|
||||
SecurityAnalyzers = {
|
||||
'invariant': InvariantAnalyzer,
|
||||
}
|
||||
Reference in New Issue
Block a user