mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 532a284d5c | |||
| 43f6104967 | |||
| e619929909 | |||
| 93753ac2e0 | |||
| 38e19d214d | |||
| 19a4f1c3ec | |||
| 45a048f9e3 | |||
| 358d9cb3f4 | |||
| e6a2fd3fd4 | |||
| c2f308f397 | |||
| a1f1c802d9 | |||
| ad2237d7dd | |||
| aa0cd51967 | |||
| 081a1305f0 | |||
| 9912e28576 | |||
| b19a33ccad | |||
| 21e912d6fb | |||
| 0dd9b95dbe | |||
| aebb583779 | |||
| e249b920ff | |||
| d920a69f69 | |||
| a8ce888981 | |||
| e22ddc0dd6 | |||
| c370912f12 |
@@ -84,6 +84,10 @@ jobs:
|
||||
run: |
|
||||
python -m pip index versions openhands-ai > openhands_versions.txt
|
||||
OPENHANDS_VERSION=$(head -n 1 openhands_versions.txt | awk '{print $2}' | tr -d '()')
|
||||
# Ensure requirements.txt ends with newline before appending
|
||||
if [ -f requirements.txt ] && [ -s requirements.txt ]; then
|
||||
sed -i -e '$a\' requirements.txt
|
||||
fi
|
||||
echo "openhands-ai==${OPENHANDS_VERSION}" >> requirements.txt
|
||||
cat requirements.txt
|
||||
|
||||
|
||||
@@ -176,6 +176,7 @@ evaluation/gorilla/data
|
||||
evaluation/toolqa/data
|
||||
evaluation/scienceagentbench/benchmark
|
||||
evaluation/commit0_bench/repos
|
||||
evaluation/visualcodebench/
|
||||
|
||||
# openhands resolver
|
||||
output/
|
||||
|
||||
@@ -75,7 +75,7 @@ workspace_base = "./workspace"
|
||||
#run_as_openhands = true
|
||||
|
||||
# Runtime environment
|
||||
#runtime = "eventstream"
|
||||
#runtime = "docker"
|
||||
|
||||
# Name of the default agent
|
||||
#default_agent = "CodeActAgent"
|
||||
|
||||
@@ -0,0 +1,674 @@
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
from difflib import SequenceMatcher
|
||||
from io import BytesIO
|
||||
|
||||
from bs4 import BeautifulSoup, Comment, NavigableString, Tag
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from colormath.color_conversions import convert_color
|
||||
from colormath.color_diff import delta_e_cie2000
|
||||
from colormath.color_objects import LabColor, sRGBColor
|
||||
from PIL import Image, ImageChops, ImageColor
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def calculate_similarity(block1, block2):
    """Return the text similarity of two blocks.

    Both arguments are mappings with a ``'text'`` entry. The result is the
    ``difflib.SequenceMatcher`` ratio of the two texts (1.0 for identical
    texts, 0.0 when nothing matches).
    """
    matcher = SequenceMatcher(None, block1['text'], block2['text'])
    return matcher.ratio()
|
||||
|
||||
|
||||
def adjust_cost_for_context(cost_matrix, consecutive_bonus=1.0, window_size=20):
    """Return a copy of ``cost_matrix`` with a neighbourhood bonus applied.

    Only cells whose cost is below -0.5 are adjusted. For such a cell, the
    surrounding window (``window_size`` cells in every direction, clipped to
    the matrix) is read from the *unadjusted* matrix, one occurrence of the
    cell's own value is dropped, and the ``2 * window_size`` lowest remaining
    costs are summed, scaled by ``consecutive_bonus`` and added to the cell.

    When ``window_size`` is not positive the input matrix itself is returned.
    """
    if window_size <= 0:
        return cost_matrix

    rows, cols = cost_matrix.shape
    adjusted = np.copy(cost_matrix)

    for r, c in np.ndindex(rows, cols):
        if adjusted[r][c] >= -0.5:
            continue

        window = cost_matrix[
            max(0, r - window_size) : min(rows, r + window_size + 1),
            max(0, c - window_size) : min(cols, c + window_size + 1),
        ]
        # Descending order, so the lowest costs sit at the tail.
        descending = np.sort(window.flatten())[::-1]
        own_position = np.where(descending == cost_matrix[r, c])[0][0]
        neighbours = np.delete(descending, own_position)
        lowest = neighbours[-window_size * 2 :]
        adjusted[r][c] += consecutive_bonus * np.sum(lowest)

    return adjusted
|
||||
|
||||
|
||||
def create_cost_matrix(A, B):
    """Build the ``len(A) x len(B)`` matching cost matrix.

    Entry ``[i, j]`` is the negated text similarity of ``A[i]`` and ``B[j]``,
    so more similar pairs get a lower cost.
    """
    cost_matrix = np.zeros((len(A), len(B)))
    for row, block_a in enumerate(A):
        for col, block_b in enumerate(B):
            cost_matrix[row, col] = -calculate_similarity(block_a, block_b)
    return cost_matrix
|
||||
|
||||
|
||||
def calculate_distance_max_1d(x1, y1, x2, y2):
    """Return the larger of the horizontal and vertical distances between two points."""
    dx = abs(x2 - x1)
    dy = abs(y2 - y1)
    return dy if dy > dx else dx
|
||||
|
||||
|
||||
def calculate_ratio(h1, h2):
    """Return the larger of the two heights divided by the smaller one.

    Raises ``ZeroDivisionError`` when the smaller height is zero.
    """
    larger = max(h1, h2)
    smaller = min(h1, h2)
    return larger / smaller
|
||||
|
||||
|
||||
def rgb_to_lab(rgb):
    """Convert an RGB triple to a colormath ``LabColor``.

    The first three items of ``rgb`` are passed to ``sRGBColor`` with
    ``is_upscaled=True``.
    """
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return convert_color(sRGBColor(red, green, blue, is_upscaled=True), LabColor)
|
||||
|
||||
|
||||
def color_similarity_ciede2000(rgb1, rgb2):
    """Return a colour similarity derived from the CIEDE2000 difference.

    The delta-E between the two colours (converted to Lab) is mapped to
    ``1 - delta_e / 100`` and floored at 0.
    """
    delta_e = delta_e_cie2000(rgb_to_lab(rgb1), rgb_to_lab(rgb2))
    return max(0, 1 - (delta_e / 100))
|
||||
|
||||
|
||||
def merge_blocks_wo_check(block1, block2):
    """Merge two blocks into a new block without any validation.

    Texts are joined with a single space, the bounding box is the smallest
    ``(x, y, width, height)`` rectangle covering both boxes, and the colour is
    the per-channel floor average. The inputs are not modified.
    """
    text = ' '.join((block1['text'], block2['text']))

    box1, box2 = block1['bbox'], block2['bbox']
    left = min(box1[0], box2[0])
    top = min(box1[1], box2[1])
    right = max(box1[0] + box1[2], box2[0] + box2[2])
    bottom = max(box1[1] + box1[3], box2[1] + box2[3])

    color = tuple(
        (channel1 + channel2) // 2
        for channel1, channel2 in zip(block1['color'], block2['color'])
    )
    return {
        'text': text,
        'bbox': (left, top, right - left, bottom - top),
        'color': color,
    }
|
||||
|
||||
|
||||
def find_maximum_matching(A, B, consecutive_bonus, window_size):
    """Solve the optimal one-to-one assignment between blocks of ``A`` and ``B``.

    Returns a tuple ``(pairs, costs, cost_matrix)``: the assigned
    ``(row, col)`` index pairs, the cost of every assigned pair, and the
    context-adjusted cost matrix the assignment was computed on.
    """
    cost_matrix = adjust_cost_for_context(
        create_cost_matrix(A, B), consecutive_bonus, window_size
    )
    rows, cols = linear_sum_assignment(cost_matrix)
    pairs = list(zip(rows, cols))
    assigned_costs = cost_matrix[rows, cols].tolist()
    return pairs, assigned_costs, cost_matrix
|
||||
|
||||
|
||||
def remove_indices(lst, indices):
    """Pop the given positions from ``lst`` in place and return ``lst``.

    Positions are processed from highest to lowest so earlier removals do not
    shift later ones; positions at or beyond the current length are ignored.
    """
    for position in sorted(indices, reverse=True):
        if position >= len(lst):
            continue
        lst.pop(position)
    return lst
|
||||
|
||||
|
||||
def merge_blocks_by_list(blocks, merge_list):
    """Apply the merges in ``merge_list`` to ``blocks`` in place.

    Each entry starts with two indices ``(keep, drop)``: block ``drop`` is
    merged into block ``keep``. After a merge is applied, every remaining
    entry that mentions either index is discarded. The absorbed blocks are
    removed at the end and the (mutated) ``blocks`` list is returned.
    """
    absorbed = []
    while merge_list:
        keep, drop = merge_list[0][0], merge_list[0][1]
        blocks[keep] = merge_blocks_wo_check(blocks[keep], blocks[drop])
        absorbed.append(drop)
        merge_list.pop(0)
        if merge_list:
            touched = (keep, drop)
            # Entries sharing an index with the merge just applied are stale.
            merge_list = [
                entry
                for entry in merge_list
                if entry[0] not in touched and entry[1] not in touched
            ]
    remove_indices(blocks, absorbed)
    return blocks
|
||||
|
||||
|
||||
def difference_of_means(list1, list2):
    """Compare the means of the values the two lists do not share.

    Values common to both lists are cancelled pairwise (respecting
    multiplicity). The mean of what is left of ``list1`` minus the mean of
    what is left of ``list2`` is then computed, with an empty remainder
    counting as mean 0. A positive difference is only returned when the
    smallest leftover of ``list1`` exceeds the smallest leftover of ``list2``;
    otherwise 0.0 is returned. Non-positive differences are returned as is.

    Raises ``ValueError`` when the difference is positive but one of the
    remainders is empty, because ``min`` is then called on an empty list.
    """
    tally1, tally2 = Counter(list1), Counter(list2)
    # Counter subtraction keeps only positive counts, i.e. the unmatched values.
    only1 = list((tally1 - tally2).elements())
    only2 = list((tally2 - tally1).elements())

    mean1 = sum(only1) / len(only1) if only1 else 0
    mean2 = sum(only2) / len(only2) if only2 else 0
    gap = mean1 - mean2

    if not gap > 0:
        return gap
    return gap if min(only1) > min(only2) else 0.0
|
||||
|
||||
|
||||
def find_possible_merge(A, B, consecutive_bonus, window_size, debug=False):
    """Greedily merge adjacent blocks while that improves the matching.

    In every round, each pair of neighbouring blocks in ``A`` is tentatively
    merged and the matching against ``B`` recomputed; merges whose
    ``difference_of_means`` against the current matching cost exceeds 0.05 are
    applied, best first. The same is then done for ``B`` against the updated
    ``A``. Rounds repeat until neither side changes. During this search the
    matching uses a bonus of 0.0 and a window of 1.

    ``A`` and ``B`` are modified in place. ``debug`` is accepted but unused.

    Returns ``(A, B, matching)`` where ``matching`` is computed on the final
    blocks with ``consecutive_bonus`` and ``window_size``.
    """
    merge_bonus = 0.0
    merge_windows = 1

    def match(left, right):
        return find_maximum_matching(left, right, merge_bonus, merge_windows)

    def beneficial_merges(blocks, baseline_cost, rematch):
        """List ``[i, i + 1, gain]`` for neighbour merges worth applying, best first."""
        candidates = []
        for idx in range(len(blocks) - 1):
            trial = deepcopy(blocks)
            trial[idx] = merge_blocks_wo_check(trial[idx], trial[idx + 1])
            trial.pop(idx + 1)
            _, trial_cost, _ = rematch(trial)
            gain = difference_of_means(baseline_cost, trial_cost)
            if gain > 0.05:
                candidates.append([idx, idx + 1, gain])
        candidates.sort(key=lambda entry: entry[2], reverse=True)
        return candidates

    while True:
        merged_any = False
        _, current_cost, _ = match(A, B)

        if len(A) >= 2:
            merges = beneficial_merges(A, current_cost, lambda trial: match(trial, B))
            if merges:
                merged_any = True
                A = merge_blocks_by_list(A, merges)
                _, current_cost, _ = match(A, B)

        if len(B) >= 2:
            merges = beneficial_merges(B, current_cost, lambda trial: match(A, trial))
            if merges:
                merged_any = True
                B = merge_blocks_by_list(B, merges)
                _, current_cost, _ = match(A, B)

        if not merged_any:
            break

    matching, _, _ = find_maximum_matching(A, B, consecutive_bonus, window_size)
    return A, B, matching
|
||||
|
||||
|
||||
def merge_blocks_by_bbox(blocks):
    """Collapse blocks that share exactly the same bounding box.

    The first block seen for a bounding box is kept (and updated in place):
    texts of later duplicates are appended after a space and its colour
    becomes a list holding the per-channel average with each duplicate.
    Blocks are returned in order of first appearance.
    """
    by_bbox = {}
    for candidate in blocks:
        key = tuple(candidate['bbox'])
        if key not in by_bbox:
            by_bbox[key] = candidate
            continue
        kept = by_bbox[key]
        kept['text'] += ' ' + candidate['text']
        kept['color'] = [
            (kept_channel + channel) / 2
            for kept_channel, channel in zip(kept['color'], candidate['color'])
        ]
    return list(by_bbox.values())
|
||||
|
||||
|
||||
def mask_bounding_boxes_with_inpainting(image, bounding_boxes):
    """Erase the given boxes from ``image`` with OpenCV inpainting.

    Each bounding box is ``(x, y, width, height)`` expressed as fractions of
    the image width/height. The covered pixels are marked in a mask and
    filled in with ``cv2.inpaint`` (Telea, radius 3). A new PIL image is
    returned.
    """
    bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    img_height, img_width = bgr.shape[:2]
    mask = np.zeros((img_height, img_width), dtype=np.uint8)

    for x_ratio, y_ratio, w_ratio, h_ratio in bounding_boxes:
        left = int(x_ratio * img_width)
        top = int(y_ratio * img_height)
        box_width = int(w_ratio * img_width)
        box_height = int(h_ratio * img_height)
        mask[top : top + box_height, left : left + box_width] = 255

    restored = cv2.inpaint(bgr, mask, 3, cv2.INPAINT_TELEA)
    return Image.fromarray(cv2.cvtColor(restored, cv2.COLOR_BGR2RGB))
|
||||
|
||||
|
||||
def rescale_and_mask(image, blocks):
    """Inpaint the given boxes (if any) and resize the image to a square.

    ``blocks`` is a list of bounding boxes; when it is empty the image is not
    masked. The result is resized (not cropped) with LANCZOS resampling to a
    square whose side is the shorter of the image's width and height.
    """
    masked = mask_bounding_boxes_with_inpainting(image, blocks) if blocks else image
    width, height = masked.size
    side = width if width < height else height
    return masked.resize((side, side), Image.LANCZOS)
|
||||
|
||||
|
||||
def calculate_clip_similarity(image1, image2, blocks1, blocks2):
    """Return the cosine similarity of the CLIP image embeddings of two images.

    Before embedding, the bounding boxes of ``blocks1`` are masked out of
    ``image1`` and those of ``blocks2`` out of ``image2``, and both images are
    rescaled to squares. The CLIP model and processor are loaded on every
    call and run on CUDA when available, otherwise on CPU.
    """
    checkpoint = 'openai/clip-vit-base-patch32'
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Mask and preprocess images
    masked_images = [
        rescale_and_mask(image1, [block['bbox'] for block in blocks1]),
        rescale_and_mask(image2, [block['bbox'] for block in blocks2]),
    ]
    batch = processor(images=masked_images, return_tensors='pt', padding=True)
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    # Calculate features and similarity
    with torch.no_grad():
        features = model.get_image_features(**batch)
        first = features[0].unsqueeze(0)
        second = features[1].unsqueeze(0)
        first = first / first.norm(dim=-1, keepdim=True)
        second = second / second.norm(dim=-1, keepdim=True)
        similarity = (first @ second.T).item()

    return similarity
|
||||
|
||||
|
||||
def rgb_to_hex(rgb):
    """Format an RGB triple as six upper-case hex digits (no leading ``#``)."""
    template = '{:02X}' * 3
    return template.format(*rgb)
|
||||
|
||||
|
||||
class ColorPool:
    """Pool of distinct hex colours handed out one at a time.

    The pool holds every combination of 16 evenly spaced channel levels
    (10, 26, ..., 250), each shifted by ``offset`` modulo 256, formatted as
    hex strings without a leading ``#``.
    """

    def __init__(self, offset=0):
        channel_levels = list(range(10, 251, 16))
        pool = []
        for red in channel_levels:
            for green in channel_levels:
                for blue in channel_levels:
                    shifted = (
                        (red + offset) % 256,
                        (green + offset) % 256,
                        (blue + offset) % 256,
                    )
                    pool.append(rgb_to_hex(shifted))
        self.color_pool = pool

    def pop_color(self):
        """Remove and return the last colour; raise ``NotImplementedError`` when empty."""
        if not self.color_pool:
            raise NotImplementedError
        return self.color_pool.pop()
|
||||
|
||||
|
||||
def process_html_str(html_str, offset=0):
    """Recolour an HTML document so every text element gets a unique colour.

    Every element has its inline ``background-color`` forced to fully
    transparent white. Every element whose tag is in the text-tag list then
    gets a unique inline ``color`` drawn from ``ColorPool(offset)`` and an
    ``opacity`` of 1.0. All forced declarations carry ``!important``.

    Args:
        html_str: The HTML source to process.
        offset: Channel offset forwarded to ``ColorPool``.

    Returns:
        The modified document serialised back to a string.

    Raises:
        NotImplementedError: Propagated from ``ColorPool.pop_color`` when the
            document has more text elements than the pool has colours.
    """
    soup = BeautifulSoup(html_str, 'html.parser')

    def update_style(element, property_name, value):
        """Replace (or add) one inline declaration, marked ``!important``."""
        kept = []
        for declaration in element.attrs.get('style', '').split(';'):
            declaration = declaration.strip()
            if not declaration:
                continue
            # Compare the whole property name: a prefix test would also drop
            # e.g. ``color-scheme`` when replacing ``color``.
            name = declaration.partition(':')[0].strip().lower()
            if name != property_name:
                kept.append(declaration)
        kept.append(f'{property_name}: {value} !important')
        element['style'] = '; '.join(kept).strip()

    # Set background color of all elements to transparent white
    for element in soup.find_all(True):
        update_style(element, 'background-color', 'rgba(255, 255, 255, 0.0)')

    color_pool = ColorPool(offset)
    text_tags = [
        'p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'div', 'span', 'a', 'b', 'li',
        'table', 'td', 'th', 'button', 'footer', 'header', 'figcaption',
    ]

    for tag in soup.find_all(text_tags):
        update_style(tag, 'color', f'#{color_pool.pop_color()}')
        update_style(tag, 'opacity', '1.0')

    return str(soup)
|
||||
|
||||
|
||||
def similar(n1, n2):
    """Return True when the two numbers differ by at most 8."""
    gap = abs(n1 - n2)
    return gap <= 8
|
||||
|
||||
|
||||
def find_different_pixels(image1, image2):
    """Locate pixels whose colour in ``image2`` is ``image1``'s colour shifted by 50.

    A pixel is selected when, for each of the R, G and B channels,
    ``(image1 + 50) % 256`` is within 8 of the ``image2`` value (the same
    tolerance ``similar`` applies).

    Args:
        image1: Image with the base colours.
        image2: Image expected to show the shifted colours.

    Returns:
        An ``(N, 2)`` contiguous integer array of ``(row, column)``
        coordinates, ordered column by column, or ``None`` when the images
        differ in size (a warning is logged) or no pixel qualifies.
    """
    if image1.size != image2.size:
        logger.warning("Images are not the same size")
        return None

    # int16 so the +50 shift and the subtraction cannot wrap around in uint8.
    base = np.asarray(image1.convert('RGB'), dtype=np.int16)
    shifted = np.asarray(image2.convert('RGB'), dtype=np.int16)

    tolerance = 8
    expected = (base + 50) % 256
    matches = np.all(np.abs(expected - shifted) <= tolerance, axis=-1)

    # Transposing yields (x, y) pairs in x-major order, as the per-pixel scan
    # did; flipping the columns turns each pair into (row, column).
    coords = np.argwhere(matches.T)[:, ::-1]
    if coords.size == 0:
        return None
    # Contiguous memory is required by consumers that re-view the rows.
    return np.ascontiguousarray(coords)
|
||||
|
||||
|
||||
def extract_text_with_color(html_str):
    """Extract the visible text of an HTML body together with its colour.

    The ``<body>`` is walked recursively. Each non-empty text node becomes a
    ``(text, color)`` tuple, where ``color`` is the inline ``color`` of the
    nearest enclosing element that declares one, or ``'#000000'`` when none
    does. Comments are skipped. Each element contributes a list of its
    children's results, so the return value is a nested list (empty when the
    document has no body).
    """

    def get_color(tag):
        """Return the tag's inline text colour as ``#...`` hex, or None."""
        if 'style' not in tag.attrs:
            return None

        # Only the ``color`` property itself counts; a substring test would
        # also pick up ``border-color``, ``outline-color`` and the like.
        values = []
        for declaration in tag['style'].split(';'):
            name, separator, value = declaration.partition(':')
            if separator and name.strip().lower() == 'color':
                values.append(value)
        if not values:
            return None

        # The last declaration wins.
        color = values[-1].replace('!important', '').strip()
        if not color:
            return None
        if color.startswith('#'):
            return color

        try:
            if color.startswith('rgb'):
                color = tuple(map(int, color[4:-1].split(',')))
            else:
                color = ImageColor.getrgb(color)
            return '#{:02x}{:02x}{:02x}'.format(*color)
        except ValueError:
            logger.warning(f"Unable to identify or convert color: {color}")
        return None

    def extract_text_recursive(element, parent_color='#000000'):
        # Comment is checked first so comments never count as text.
        if isinstance(element, Comment):
            return None
        if isinstance(element, NavigableString):
            text = element.strip()
            return (text, parent_color) if text else None
        if isinstance(element, Tag):
            current_color = get_color(element) or parent_color
            collected = [
                extract_text_recursive(child, current_color)
                for child in element.children
            ]
            # Drop empty results (comments, blank text, childless elements).
            return [entry for entry in collected if entry]
        return None

    soup = BeautifulSoup(html_str, 'html.parser')
    body = soup.body
    return extract_text_recursive(body) if body else []
|
||||
|
||||
|
||||
def flatten_tree(tree):
    """Flatten arbitrarily nested lists into one flat list, depth first.

    Only ``list`` instances are descended into; any other value (tuples
    included) is treated as a leaf. A non-list ``tree`` yields a one-element
    list.
    """

    def walk(node):
        if not isinstance(node, list):
            yield node
            return
        for child in node:
            yield from walk(child)

    return list(walk(tree))
|
||||
|
||||
|
||||
def get_blocks_from_image_diff_pixels(image, html_text_color_tree, different_pixels):
    """Locate each coloured text item in ``image`` and describe it as a block.

    Args:
        image: Rendered page (converted with ``np.array`` as RGB).
        html_text_color_tree: Flat iterable of ``(text, '#rrggbb')`` items.
            Items whose colour cannot be parsed are skipped.
        different_pixels: ``(N, 2)`` array of ``(row, column)`` coordinates;
            only pixels listed here are considered.

    Returns:
        A list of dicts with the lower-cased ``'text'``, a ``'bbox'`` of
        ``(x, y, width, height)`` as fractions of the image size, and the
        average RGB ``'color'`` of the matched pixels.
    """
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    n_rows, n_cols = image_cv.shape[0], image_cv.shape[1]

    def hex_to_bgr(hex_color):
        hex_color = hex_color.lstrip('#')
        rgb = tuple(int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
        return rgb[::-1]

    def get_intersect(arr1, arr2):
        # View each row as a single structured element so that
        # np.intersect1d compares whole (row, column) pairs.
        arr1_rows = arr1.view([('', arr1.dtype)] * arr1.shape[1])
        arr2_rows = arr2.view([('', arr2.dtype)] * arr2.shape[1])
        common_rows = np.intersect1d(arr1_rows, arr2_rows)
        return common_rows.view(arr1.dtype).reshape(-1, arr1.shape[1])

    # The structured view needs contiguous rows with the same dtype as the
    # coordinates produced by np.where.
    allowed_pixels = np.ascontiguousarray(different_pixels, dtype=np.intp)

    blocks = []
    for item in html_text_color_tree:
        try:
            color = np.array(hex_to_bgr(item[1]), dtype=np.int16)
        except (AttributeError, IndexError, TypeError, ValueError):
            # Missing or malformed colour: this item cannot be located.
            continue

        # Widen before applying the tolerance and clamp afterwards: in uint8,
        # ``0 - 4`` and ``252 + 4`` would wrap around and invert the range.
        lower = np.clip(color - 4, 0, 255).astype(np.uint8)
        upper = np.clip(color + 4, 0, 255).astype(np.uint8)
        mask = cv2.inRange(image_cv, lower, upper)
        coords = np.column_stack(np.where(mask > 0))
        coords = get_intersect(coords, allowed_pixels)

        if coords.size == 0:
            continue

        row_min, col_min = np.min(coords, axis=0)
        row_max, col_max = np.max(coords, axis=0)

        # Get average color from original image
        matched_colors = image_cv[coords[:, 0], coords[:, 1]]
        avg_color = tuple(map(int, np.mean(matched_colors, axis=0)))[::-1]  # Convert BGR to RGB

        blocks.append({
            'text': item[0].lower(),
            'bbox': (
                col_min / n_cols,
                row_min / n_rows,
                (col_max - col_min + 1) / n_cols,
                (row_max - row_min + 1) / n_rows,
            ),
            'color': avg_color,
        })

    return blocks
|
||||
|
||||
|
||||
def get_blocks_from_html(html_str, image1):
    """Extract text blocks for a page from its HTML and its screenshot.

    The HTML is recoloured twice (offsets 0 and 50). Rendering the second
    version is still a TODO, so a solid (255, 0, 0) image of the same size as
    ``image1`` currently stands in for it. Returns the detected blocks, or an
    empty list when no candidate pixels are found or block extraction fails.
    """
    # Process HTML with two different color offsets
    recolored_html = process_html_str(html_str, offset=0)
    _shifted_html = process_html_str(html_str, offset=50)

    # Render both HTML versions to images
    # TODO: Screenshot html_str_2
    filter_color = (255, 0, 0)
    image2 = Image.new("RGB", image1.size, filter_color)

    # Find pixels that differ between the two rendered images
    different_pixels = find_different_pixels(image1, image2)
    if different_pixels is None:
        logger.warning("Unable to get pixels with different colors")
        return []

    # Extract text and color information from HTML
    text_color_items = flatten_tree(extract_text_with_color(recolored_html))
    try:
        return get_blocks_from_image_diff_pixels(
            image1, text_color_items, different_pixels
        )
    except Exception as e:
        logger.warning(f"Unable to get blocks: {e}")
        return []
|
||||
|
||||
|
||||
def evaluate(task, generated_img):
    """Decide whether a generated page matches the reference page.

    Args:
        task: Mapping with ``'post_image'`` (reference screenshot),
            ``'post_html'`` (reference HTML) and ``'gen_html'`` (generated
            HTML).
        generated_img: Screenshot of the generated page.

    Returns:
        A truthy value when the pages are considered a match:
        * if no text blocks can be extracted for either page: CLIP
          similarity > 0.95 or the images are pixel-identical;
        * if no block pair has text similarity >= 0.5: CLIP similarity
          > 0.95;
        * otherwise: the equally weighted mean of the size, text, position,
          colour and CLIP scores is > 0.8.
    """
    # Load reference image
    post_image = task['post_image']

    # Extract blocks from HTML and images
    post_blocks = get_blocks_from_html(task['post_html'], post_image)
    gen_blocks = get_blocks_from_html(task['gen_html'], generated_img)
    logger.debug(f'block details: {post_blocks} {gen_blocks}')

    if not post_blocks or not gen_blocks:
        # Fallback to basic CLIP and pixel comparison if no blocks available
        clip_score = calculate_clip_similarity(post_image, generated_img, [], [])
        logger.info(f'CLIP similarity score: {clip_score}')

        # Pixel comparison. ImageChops.difference raises on images of
        # different size or mode, which can never be pixel-identical anyway.
        comparable = (
            generated_img.size == post_image.size
            and generated_img.mode == post_image.mode
        )
        if comparable:
            diff = ImageChops.difference(generated_img, post_image)
            pixel_match = not diff.getbbox()
        else:
            pixel_match = False
        logger.info(
            f"Pixel difference analysis: {'No difference' if pixel_match else 'Differences found'}"
        )

        return clip_score > 0.95 or pixel_match

    # Merge blocks with same bounding boxes
    post_blocks = merge_blocks_by_bbox(post_blocks)
    gen_blocks = merge_blocks_by_bbox(gen_blocks)

    # Find optimal block matching
    # NOTE: gen_blocks is passed as is and is merged in place by
    # find_possible_merge; only post_blocks is copied.
    consecutive_bonus, window_size = 0.1, 1
    gen_blocks_m, post_blocks_m, matching = find_possible_merge(
        gen_blocks, deepcopy(post_blocks), consecutive_bonus, window_size
    )

    # Filter matches with low similarity
    filtered_matching = []
    for i, j in matching:
        text_similarity = calculate_similarity(gen_blocks_m[i], post_blocks_m[j])
        if text_similarity >= 0.5:
            filtered_matching.append([i, j, text_similarity])
    matching = filtered_matching

    if not matching:
        logger.warning('No matching blocks found')
        # Each image is masked with its own blocks.
        clip_score = calculate_clip_similarity(
            post_image, generated_img, post_blocks, gen_blocks
        )
        return clip_score > 0.95

    matched_gen = {item[0] for item in matching}
    matched_post = {item[1] for item in matching}

    # Calculate unmatched areas
    unmatched_area_1 = sum(
        block['bbox'][2] * block['bbox'][3]
        for i, block in enumerate(gen_blocks_m)
        if i not in matched_gen
    )
    unmatched_area_2 = sum(
        block['bbox'][2] * block['bbox'][3]
        for j, block in enumerate(post_blocks_m)
        if j not in matched_post
    )
    total_unmatched_area = unmatched_area_1 + unmatched_area_2

    # Calculate metrics for matched blocks
    matched_areas = []
    text_scores = []
    position_scores = []
    color_scores = []

    for i, j, text_similarity in matching:
        gen_block = gen_blocks_m[i]
        post_block = post_blocks_m[j]
        gen_box = gen_block['bbox']
        post_box = post_block['bbox']

        # Area
        matched_areas.append(gen_box[2] * gen_box[3] + post_box[2] * post_box[3])

        # Position similarity: 1 minus the larger axis distance between the
        # two box centres.
        position_similarity = 1 - calculate_distance_max_1d(
            gen_box[0] + gen_box[2] / 2,
            gen_box[1] + gen_box[3] / 2,
            post_box[0] + post_box[2] / 2,
            post_box[1] + post_box[3] / 2,
        )

        # Color similarity
        color_similarity = color_similarity_ciede2000(
            gen_block['color'], post_block['color']
        )

        text_scores.append(text_similarity)
        position_scores.append(position_similarity)
        color_scores.append(color_similarity)

    # Calculate final scores
    total_area = sum(matched_areas) + total_unmatched_area
    size_score = sum(matched_areas) / total_area if total_area > 0 else 0
    text_score = np.mean(text_scores) if text_scores else 0
    position_score = np.mean(position_scores) if position_scores else 0
    color_score = np.mean(color_scores) if color_scores else 0
    # Each image is masked with its own blocks.
    clip_score = calculate_clip_similarity(
        post_image, generated_img, post_blocks, gen_blocks
    )

    # Combine scores with equal weights
    final_score = 0.2 * (
        size_score + text_score + position_score + color_score + clip_score
    )

    logger.info('Evaluation scores:')
    logger.info(f'- Size score: {size_score:.3f}')
    logger.info(f'- Text score: {text_score:.3f}')
    logger.info(f'- Position score: {position_score:.3f}')
    logger.info(f'- Color score: {color_score:.3f}')
    logger.info(f'- CLIP score: {clip_score:.3f}')
    logger.info(f'- Final score: {final_score:.3f}')

    return final_score > 0.8  # Consider it a match if final score > 80%
|
||||
|
||||
|
||||
def png_to_bytes(png):
    """Serialise an image to PNG and return the encoded bytes.

    ``png`` must provide ``save(file_object, format=...)``.
    """
    with BytesIO() as buffer:
        png.save(buffer, format='PNG')
        return buffer.getvalue()
|
||||
|
||||
|
||||
def bytes_to_image(image_bytes):
    """Open encoded image bytes as a Pillow ``Image``."""
    stream = BytesIO(image_bytes)
    return Image.open(stream)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke run: evaluate sample 1's "prev" page against its "post" page.
    first_image = Image.open('./evaluation/visualcodebench/data/1/post.png')
    image = Image.open('./evaluation/visualcodebench/data/1/prev.png')

    # Context managers close the files even when reading fails.
    with open('./evaluation/visualcodebench/data/1/post/index.html', 'r') as html_file:
        first_html = html_file.read()

    with open('./evaluation/visualcodebench/data/1/prev/index.html', 'r') as html_file:
        gen_html = html_file.read()

    sample = {'post_image': first_image, "post_html": first_html, "gen_html": gen_html}

    evaluate(sample, image)
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
import base64
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import PngImagePlugin
|
||||
from tqdm import tqdm
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Directory to store the downloaded repository
REPO_DOWNLOAD_DIR = './evaluation/visualcodebench/'
|
||||
|
||||
|
||||
def download_repository():
    """Download the VisualCodeBench dataset repository from the Hugging Face Hub.

    All files of the ``rvmalhot/VisualCodeBench`` dataset are fetched into
    ``REPO_DOWNLOAD_DIR``. Any failure is logged and re-raised.
    """
    repo_id = 'rvmalhot/VisualCodeBench'
    logger.info(f"Downloading repository '{repo_id}'...")
    try:
        snapshot_download(
            repo_id=repo_id,
            local_dir=REPO_DOWNLOAD_DIR,
            repo_type='dataset',
            ignore_patterns=None,  # Download all files
        )
    except Exception as e:
        logger.error(f"Error downloading repository '{repo_id}': {e}")
        raise
    logger.info(f"Repository downloaded to '{REPO_DOWNLOAD_DIR}'.")
|
||||
|
||||
|
||||
def format_task_dict(example):
    """Convert one dataset example into a task dict.

    The task copies the example's images, change description and code files,
    exposes its ``'id'`` as ``'instance_id'``, and sets ``'skip'`` to True
    unless both ``data/<id>/prev`` and ``data/<id>/post`` exist under
    ``REPO_DOWNLOAD_DIR``.
    """
    instance_id = example['id']
    prev_remote_path = os.path.join(REPO_DOWNLOAD_DIR, f'data/{instance_id}/prev')
    post_remote_path = os.path.join(REPO_DOWNLOAD_DIR, f'data/{instance_id}/post')

    # Both the 'prev' and the 'post' directory are needed to run the task.
    prev_exists = os.path.exists(prev_remote_path)
    post_exists = os.path.exists(post_remote_path)
    has_both_dirs = prev_exists and post_exists

    copied_fields = (
        'prev_image',
        'post_image',
        'changes',
        'prev_code_files',
        'post_code_files',
    )
    task = {'instance_id': instance_id}
    for field in copied_fields:
        task[field] = example[field]
    task['skip'] = not has_both_dirs
    return task
|
||||
|
||||
|
||||
def prepare_visualcodebench(dataset):
    """Turn the dataset's ``'train'`` split into a DataFrame of runnable tasks.

    Examples flagged ``'skip'`` by ``format_task_dict`` are left out; the
    ``'skip'`` key itself is not part of the resulting rows.
    """
    logger.info('Processing dataset')
    rows = []
    for example in tqdm(dataset['train']):
        task = format_task_dict(example)
        if not task.pop('skip'):
            rows.append(task)
    return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def pil_image_to_base64(image: PngImagePlugin.PngImageFile) -> list[str]:
    """
    Converts a PIL PNG image to a Base64 data URL.

    Parameters:
    - image (PngImagePlugin.PngImageFile): The PIL image to convert.

    Returns:
    - list[str]: A one-element list holding the image encoded as a
      ``data:image/png;base64,...`` URL.

    Raises:
    - ValueError: If ``image`` is not a ``PngImagePlugin.PngImageFile``.
    """
    if not isinstance(image, PngImagePlugin.PngImageFile):
        raise ValueError(
            'The provided image is not a PIL.PngImagePlugin.PngImageFile instance.'
        )

    buffered = BytesIO()
    image.save(buffered, format='PNG')
    img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    # The single-element list is the established return shape; callers rely on it.
    return [f'data:image/png;base64,{img_base64}']
|
||||
@@ -0,0 +1,247 @@
|
||||
# FILE: run_infer.py
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
|
||||
# from evaluation.benchmarks.visualcodebench.eval import capture_screenshot
|
||||
from evaluation.benchmarks.visualcodebench.prepare import (
|
||||
REPO_DOWNLOAD_DIR,
|
||||
download_repository,
|
||||
pil_image_to_base64,
|
||||
prepare_visualcodebench,
|
||||
)
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
assert_and_raise,
|
||||
codeact_user_response,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
AppConfig,
|
||||
SandboxConfig,
|
||||
get_llm_config_arg,
|
||||
)
|
||||
from openhands.core.config.utils import parse_arguments
|
||||
from openhands.core.logger import openhands_logger as logger # Import OpenHands logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action.commands import CmdRunAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.observation.commands import CmdOutputObservation
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
# Define workspace and output directories
WORKSPACE_DIR = './workspace'

# Simulated user reply per agent class.
FAKE_RESPONSES = {
    'CodeActAgent': partial(codeact_user_response, encapsulate_solution=True),
}
|
||||
|
||||
|
||||
def get_config(
    metadata: EvalMetadata,
) -> AppConfig:
    """Build the ``AppConfig`` for one evaluation run.

    The agent class, iteration limit and LLM config come from ``metadata``.
    The sandbox uses the ``python:3.12-bookworm`` base image with auto-lint
    enabled and host networking disabled; no workspace is mounted.
    """
    sandbox = SandboxConfig(
        base_container_image='python:3.12-bookworm',
        enable_auto_lint=True,
        use_host_network=False,
    )
    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        runtime='eventstream',
        max_iterations=metadata.max_iterations,
        sandbox=sandbox,
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )
    config.set_llm_config(metadata.llm_config)
    return config
|
||||
|
||||
|
||||
def initialize_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent.
    It creates ``/workspace/<instance_id>`` inside the sandbox, copies the
    instance's ``prev/index.html`` into it and changes into that directory.

    Args:
        runtime: The connected runtime in which commands are executed.
        instance: Dataset row; its ``instance_id`` names the workspace folder.

    Raises:
        Whatever ``assert_and_raise`` raises when a sandbox command exits
        with a non-zero code.
    """
    logger.info('-' * 30)
    logger.info('BEGIN Runtime Initialization Fn')
    logger.info('-' * 30)
    workspace_dir_name = instance['instance_id']
    obs: CmdOutputObservation

    # The command must be an f-string: without the prefix the sandbox would
    # create a directory literally named "{workspace_dir_name}".
    action = CmdRunAction(command=f'mkdir -p /workspace/{workspace_dir_name}')
    action.timeout = 600
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert_and_raise(
        obs.exit_code == 0,
        f'Failed to create /workspace/{workspace_dir_name}: {str(obs)}',
    )

    # Copy the starting HTML file for this instance into the sandbox.
    file_path = REPO_DOWNLOAD_DIR + f'data/{workspace_dir_name}/prev/index.html'
    runtime.copy_to(file_path, f'/workspace/{workspace_dir_name}')
    logger.info(f'Copied code file for instance {workspace_dir_name}')

    action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
    action.timeout = 600
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    assert_and_raise(
        obs.exit_code == 0,
        f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
    )

    logger.info('-' * 30)
    logger.info('END Runtime Initialization Fn')
    logger.info('-' * 30)
|
||||
|
||||
|
||||
def complete_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required, but it is used to get the workspace_dir_name
) -> str | None:
    """Post-run hook for one instance (currently a stub).

    Copies the instance's reference ``post/`` folder from the downloaded
    repository into a temporary directory, which is removed again when the
    function exits. Nothing is extracted from the agent workspace yet, so the
    function always returns ``None``.

    Args:
        runtime: The runtime the agent ran in (not used yet).
        instance: Dataset row; ``instance.instance_id`` selects the folder.

    Returns:
        ``None`` until the TODOs below are implemented.
    """
    # TODO: extract edited HTML file from agent workspace
    # temp_zip = runtime.copy_from(f'/workspace/{instance.instance_id}')
    # file_name = f'/workspace/{instance.instance_id}/index.html'
    # with zipfile.ZipFile(temp_zip, 'r') as zip_ref:
    #     if file_name in zip_ref.namelist():
    #         with zip_ref.open(file_name) as file:
    #             file_content = file.read().decode('utf-8')  # Decode bytes to string
    #     else:
    #         raise FileNotFoundError(f"'{file_name}' not found in the ZIP archive.")

    with tempfile.TemporaryDirectory() as tmpdir:
        src_folder = REPO_DOWNLOAD_DIR + f'data/{instance.instance_id}/post/'
        shutil.copytree(src_folder, tmpdir, dirs_exist_ok=True)

        # image = capture_screenshot(tmpdir)
        # if image is not None:
        #     shutil.copy(os.path.join(tmpdir, 'final_screenshot.png'), REPO_DOWNLOAD_DIR)

    # Explicit so the signature no longer promises a string that is never
    # produced.
    return None
|
||||
|
||||
|
||||
def process_instance(
    instance: pd.Series, metadata: EvalMetadata, reset_logger: bool = True
):
    """Run the agent on one VisualCodeBench instance.

    Builds the instruction from ``instance['changes']``, starts a runtime,
    sends the instruction together with the "prev" image to the agent and
    finally calls ``complete_runtime``. The runtime is always closed.

    Args:
        instance: Dataset row describing the task.
        metadata: Evaluation metadata (agent class, output dir, LLM config).
        reset_logger: When true, redirect logging to a per-instance file so
            several instances can be processed in parallel.

    Raises:
        ValueError: If the controller returns no state.
    """
    config = get_config(metadata)

    # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
    if not reset_logger:
        logger.info(f'Starting evaluation for instance {instance.instance_id}.')
    else:
        infer_log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance.instance_id, infer_log_dir)

    # =============================================
    # build instruction
    # =============================================
    instruction_parts = [
        'Modify the HTML/CSS according to the following instruction:\n\n',
        f"{instance['changes']}\n\n",
        'IMPORTANT: You should ONLY interact with the environment provided ',
        'to you AND NEVER ASK FOR HUMAN HELP.\n',
    ]
    instruction = ''.join(instruction_parts)

    # =============================================
    # create sandbox and run the agent
    # =============================================
    runtime: Runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)

    try:
        initialize_runtime(runtime, instance=instance)

        initial_action = MessageAction(
            content=instruction,
            image_urls=pil_image_to_base64(instance['prev_image']),
        )
        controller_run = run_controller(
            config=config,
            initial_user_action=initial_action,
            runtime=runtime,
            fake_user_response_fn=FAKE_RESPONSES[metadata.agent_class],
        )
        state: State | None = asyncio.run(controller_run)
        if state is None:
            raise ValueError('State should not be None.')

        # =============================================
        # result evaluation
        # =============================================
        return_val = complete_runtime(runtime, instance)
        logger.info(f'Return value {return_val}')
    finally:
        runtime.close()

    # TODO: return EVAL output
|
||||
|
||||
|
||||
def main():
    """Main function to run the evaluation.

    Parses the command-line arguments, downloads and loads the dataset,
    builds the evaluation metadata and runs ``process_instance`` over the
    prepared instances.

    Raises:
        ValueError: If the requested LLM config cannot be found.
    """
    args = parse_arguments()

    logger.info(f"\n{'='*80}\nStarting VisualCodeBench Evaluation\n{'='*80}")
    logger.info(f'Agent: {args.agent_cls}')
    logger.info(f'Model: {args.llm_config}')
    logger.info(f'Max iterations: {args.max_iterations}')
    logger.info(f'Eval limit: {args.eval_n_limit}')
    logger.info(f'Num workers: {args.eval_num_workers}\n')
    logger.info(f'Eval output: {args.eval_output_dir}\n')

    # Step 1: Download the entire repository once
    logger.info('Downloading repository...')
    download_repository()

    # Step 2: Load Dataset
    logger.info('Loading dataset...')
    dataset = load_dataset(REPO_DOWNLOAD_DIR)

    # Step 3: Prepare dataset
    llm_config = get_llm_config_arg(args.llm_config)
    if llm_config is None:
        logger.error(f'Could not find LLM config: {args.llm_config}')
        raise ValueError(f'Could not find LLM config: {args.llm_config}')

    # Use the output directory from the arguments: it is the one logged
    # above, and a hard-coded path here would silently ignore it.
    metadata = make_metadata(
        llm_config,
        'VisualCodeBench',
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
    )

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    dataset = prepare_visualcodebench(dataset)
    instances = prepare_dataset(dataset, output_file, eval_n_limit=args.eval_n_limit)

    # Step 4: Run eval
    run_evaluation(
        instances, metadata, output_file, args.eval_num_workers, process_instance
    )
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -0,0 +1,46 @@
|
||||
#!/bin/bash
# Runs the VisualCodeBench inference script for a given model config,
# commit, agent class and evaluation limit.
set -eo pipefail

source "evaluation/utils/version_control.sh"

# Check if required arguments are provided
if [ "$#" -lt 4 ]; then
  echo "Usage: $0 [model_config] [commit_hash] [agent_cls] [eval_limit] [num_workers]"
  echo "Example: $0 llm.eval_gpt_4o_mini HEAD CodeActAgent 5 1"
  exit 1
fi

MODEL_CONFIG=$1
COMMIT_HASH=$2
AGENT_CLS=$3
EVAL_LIMIT=$4
NUM_WORKERS=${5:-1}  # Default to 1 worker if not specified

# Checkout the specified commit
checkout_eval_branch

# The positional agent argument takes precedence; an AGENT environment
# variable is still honoured when the argument is empty.
AGENT="${AGENT_CLS:-$AGENT}"
if [ -z "$AGENT" ]; then
  echo "Agent not specified, use default CodeActAgent"
  AGENT="CodeActAgent"
fi

get_openhands_version

echo "AGENT: $AGENT"
echo "OPENHANDS_VERSION: $OPENHANDS_VERSION"
echo "MODEL_CONFIG: $MODEL_CONFIG"

COMMAND="export PYTHONPATH=evaluation/benchmarks/visualcodebench:\$PYTHONPATH && poetry run python evaluation/benchmarks/visualcodebench/run_infer.py \
  --agent-cls $AGENT \
  --llm-config $MODEL_CONFIG \
  --max-iterations 5 \
  --eval-num-workers $NUM_WORKERS \
  --eval-note $OPENHANDS_VERSION"

if [ -n "$EVAL_LIMIT" ]; then
  echo "EVAL_LIMIT: $EVAL_LIMIT"
  COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
fi

# Run the command
eval $COMMAND
|
||||
@@ -0,0 +1,167 @@
|
||||
import http
|
||||
import os
|
||||
import socket
|
||||
import socketserver
|
||||
import threading
|
||||
import time
|
||||
from io import BytesIO
|
||||
|
||||
import requests
|
||||
from PIL import Image, ImageChops
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def get_free_port():
    """Find a free port to run the HTTP server.

    Binds a throwaway TCP socket to port 0 and returns the port number the
    socket was given; the socket is closed before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(('', 0))
        _host, assigned_port = probe.getsockname()
    return assigned_port
|
||||
|
||||
|
||||
def start_http_server(tmpdir):
    """Create (but do not start) an HTTP server that serves files from ``tmpdir``.

    Args:
        tmpdir: Directory whose contents are served at the server root.

    Returns:
        A ``(server, port)`` tuple: the bound ``socketserver.TCPServer`` and
        the port it listens on. The caller runs ``server.serve_forever()``
        and is responsible for shutting the server down.
    """
    # Imported here because a bare ``import http`` does not load the
    # ``http.server`` submodule.
    import functools
    import http.server

    # ``directory=`` makes the stock handler serve from ``tmpdir`` instead of
    # the current working directory.
    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=tmpdir
    )
    # Bind to port 0 so the OS assigns a free port to this very socket;
    # probing for a port first and binding later could lose it in between.
    server = socketserver.TCPServer(('', 0), handler)
    port = server.server_address[1]
    return server, port
|
||||
|
||||
|
||||
def capture_screenshot(tmpdir):
    """Serve ``tmpdir`` over HTTP and capture a full-page screenshot of its root.

    The screenshot is written to ``<tmpdir>/final_screenshot.png``.

    Args:
        tmpdir: Directory containing the page to render.

    Returns:
        The loaded ``PIL.Image.Image``, or ``None`` when the screenshot could
        not be captured.

    Raises:
        RuntimeError: If the temporary HTTP server is not reachable.
    """
    server, port = start_http_server(tmpdir)
    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.daemon = True
    server_thread.start()
    # NOTE(review): fixed start-up delay kept from the original; the
    # reachability check below may make it unnecessary - confirm.
    time.sleep(10)

    image = None
    try:
        server_url = f'http://localhost:{port}/'

        if not is_server_reachable(server_url):
            raise RuntimeError(f'Server not reachable at {server_url}')

        screenshot_path = os.path.join(tmpdir, 'final_screenshot.png')
        # Only open the file when the capture succeeded; otherwise the path
        # is either missing or a stale file from an earlier run.
        if capture_screenshot_playwright(server_url, screenshot_path):
            image = Image.open(screenshot_path)
            image.load()
        else:
            logger.warning(f'No screenshot captured for {server_url}')
    finally:
        # Shut down the server and clean up
        server.shutdown()
        server.server_close()
        server_thread.join(timeout=5)

    return image
|
||||
|
||||
|
||||
def is_server_reachable(url):
    """Check if the local server is reachable.

    Args:
        url: URL to probe with a GET request (5-second timeout).

    Returns:
        ``True`` if the server answers with status 200, ``False`` for any
        other status code or if the request fails.
    """
    try:
        response = requests.get(url, timeout=5)  # Set a 5-second timeout
    except requests.RequestException as e:
        # RequestException also covers timeouts, which ConnectionError alone
        # does not; a probe should report failure rather than raise.
        logger.error(f'Failed to connect to server at {url}: {e}')
        return False

    if response.status_code == 200:
        logger.info(f'Server is reachable at {url}')
        return True

    logger.warning(
        f'Server responded with status code {response.status_code} at {url}'
    )
    return False
|
||||
|
||||
|
||||
def capture_screenshot_playwright(url, screenshot_path):
    """Capture a screenshot of the given URL using Playwright.

    Navigation and network-idle timeouts are logged and tolerated so that a
    slow page still gets a screenshot of whatever has rendered.

    Args:
        url: Page to open.
        screenshot_path: File path the full-page screenshot is written to.

    Returns:
        ``True`` if the screenshot was saved, ``False`` if any other error
        occurred (the error is logged).
    """
    try:
        with sync_playwright() as p:
            logger.info('Launching browser...')
            browser = p.chromium.launch(timeout=10000)  # 10 seconds for browser launch
            try:
                logger.info('Creating a new page...')
                page = browser.new_page()

                logger.info(f'Navigating to URL: {url}')
                try:
                    page.goto(url, timeout=60 * 1000)  # 60-second timeout
                    logger.info('Page navigation completed.')
                except Exception as e:
                    logger.warning(f'Page navigation timed out. {e}. Continuing...')

                logger.info('Waiting for network to be idle...')
                try:
                    page.wait_for_load_state(
                        'networkidle', timeout=60 * 1000
                    )  # 60-second timeout
                    logger.info('Page load state reached.')
                except Exception as e:
                    logger.warning(f'Page load state timed out. {e}. Continuing...')

                logger.info('Capturing screenshot...')
                page.screenshot(
                    path=screenshot_path, full_page=True
                )  # Capture full page screenshot

                logger.info(f'Screenshot saved to {screenshot_path}')
            finally:
                # Close the browser even when the screenshot step fails.
                browser.close()
            return True
    except Exception as e:
        logger.error(f'Error capturing screenshot with Playwright: {e}')
        return False
|
||||
|
||||
|
||||
def evaluate(task, screenshot_path, threshold=0.95):
    """Compare generated screenshot with post_image using CLIP score.

    If the CLIP comparison fails for any reason, falls back to an exact
    pixel comparison.

    Args:
        task: Mapping whose ``'post_image'`` entry holds the encoded bytes of
            the reference image.
        screenshot_path: Path of the generated screenshot.
        threshold: Cosine similarity above which the images count as a match.

    Returns:
        ``True`` if the images match, ``False`` otherwise (including when
        both comparison methods fail).
    """
    try:
        import torch
        from transformers import CLIPModel, CLIPProcessor

        # Load CLIP model and processor
        model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
        processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')

        # Load images
        post_image = Image.open(BytesIO(task['post_image'])).convert('RGB')
        generated_img = Image.open(screenshot_path).convert('RGB')

        # Process images
        inputs = processor(
            images=[post_image, generated_img], return_tensors='pt', padding=True
        )

        # Get image features; inference only, so no gradients are needed
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)

        # Calculate cosine similarity
        similarity = torch.nn.functional.cosine_similarity(
            image_features[0].unsqueeze(0), image_features[1].unsqueeze(0)
        ).item()

        logger.info(f'CLIP similarity score: {similarity}')

        return similarity > threshold
    except Exception as e:
        logger.error(f'Error in CLIP evaluation: {e}')
        # Fallback to pixel comparison if CLIP fails
        try:
            # Normalise both images to RGB so the difference is computed over
            # the colour channels regardless of the source modes.
            post_image = Image.open(BytesIO(task['post_image'])).convert('RGB')
            generated_img = Image.open(screenshot_path).convert('RGB')

            # Images of different sizes can never be pixel-identical.
            if post_image.size != generated_img.size:
                logger.info(
                    f'Pixel difference analysis: size mismatch '
                    f'{generated_img.size} vs {post_image.size}'
                )
                return False

            # Compare images directly without converting to bytes
            diff = ImageChops.difference(generated_img, post_image)
            identical = diff.getbbox() is None
            logger.info(
                f"Pixel difference analysis: {'No difference' if identical else 'Differences found'}"
            )
            return identical
        except Exception as ex:
            logger.error(f'Error in fallback evaluation: {ex}')
            return False
|
||||
@@ -155,7 +155,9 @@ describe("Sidebar", () => {
|
||||
const settingsModal = screen.getByTestId("ai-config-modal");
|
||||
|
||||
// Click the advanced options switch to show the API key input
|
||||
const advancedOptionsSwitch = within(settingsModal).getByTestId("advanced-option-switch");
|
||||
const advancedOptionsSwitch = within(settingsModal).getByTestId(
|
||||
"advanced-option-switch",
|
||||
);
|
||||
await user.click(advancedOptionsSwitch);
|
||||
|
||||
const apiKeyInput = within(settingsModal).getByLabelText(/API\$KEY/i);
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import store from "../src/store";
|
||||
import {
|
||||
setInitialQuery,
|
||||
clearInitialQuery,
|
||||
setInitialPrompt,
|
||||
clearInitialPrompt,
|
||||
} from "../src/state/initial-query-slice";
|
||||
|
||||
describe("Initial Query Behavior", () => {
|
||||
it("should clear initial query when clearInitialQuery is dispatched", () => {
|
||||
it("should clear initial query when clearInitialPrompt is dispatched", () => {
|
||||
// Set up initial query in the store
|
||||
store.dispatch(setInitialQuery("test query"));
|
||||
expect(store.getState().initialQuery.initialQuery).toBe("test query");
|
||||
store.dispatch(setInitialPrompt("test query"));
|
||||
expect(store.getState().initialQuery.initialPrompt).toBe("test query");
|
||||
|
||||
// Clear the initial query
|
||||
store.dispatch(clearInitialQuery());
|
||||
store.dispatch(clearInitialPrompt());
|
||||
|
||||
// Verify initial query is cleared
|
||||
expect(store.getState().initialQuery.initialQuery).toBeNull();
|
||||
expect(store.getState().initialQuery.initialPrompt).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -244,10 +244,14 @@ class OpenHands {
|
||||
static async createConversation(
|
||||
githubToken?: string,
|
||||
selectedRepository?: string,
|
||||
initialUserMsg?: string,
|
||||
imageUrls?: string[],
|
||||
): Promise<Conversation> {
|
||||
const body = {
|
||||
github_token: githubToken,
|
||||
selected_repository: selectedRepository,
|
||||
initial_user_msg: initialUserMsg,
|
||||
image_urls: imageUrls,
|
||||
};
|
||||
|
||||
const { data } = await openHands.post<Conversation>(
|
||||
|
||||
@@ -23,7 +23,7 @@ export const AGENT_STATUS_MAP: {
|
||||
},
|
||||
[AgentState.AWAITING_USER_INPUT]: {
|
||||
message: I18nKey.CHAT_INTERFACE$AGENT_AWAITING_USER_INPUT_MESSAGE,
|
||||
indicator: IndicatorColor.ORANGE,
|
||||
indicator: IndicatorColor.BLUE,
|
||||
},
|
||||
[AgentState.PAUSED]: {
|
||||
message: I18nKey.CHAT_INTERFACE$AGENT_PAUSED_MESSAGE,
|
||||
|
||||
@@ -12,15 +12,22 @@ interface MessagesProps {
|
||||
export const Messages: React.FC<MessagesProps> = React.memo(
|
||||
({ messages, isAwaitingUserConfirmation }) =>
|
||||
messages.map((message, index) => {
|
||||
const shouldShowConfirmationButtons =
|
||||
messages.length - 1 === index &&
|
||||
message.sender === "assistant" &&
|
||||
isAwaitingUserConfirmation;
|
||||
|
||||
if (message.type === "error" || message.type === "action") {
|
||||
return (
|
||||
<ExpandableMessage
|
||||
key={index}
|
||||
type={message.type}
|
||||
id={message.translationID}
|
||||
message={message.content}
|
||||
success={message.success}
|
||||
/>
|
||||
<div key={index}>
|
||||
<ExpandableMessage
|
||||
type={message.type}
|
||||
id={message.translationID}
|
||||
message={message.content}
|
||||
success={message.success}
|
||||
/>
|
||||
{shouldShowConfirmationButtons && <ConfirmationButtons />}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -33,9 +40,7 @@ export const Messages: React.FC<MessagesProps> = React.memo(
|
||||
{message.imageUrls && message.imageUrls.length > 0 && (
|
||||
<ImageCarousel size="small" images={message.imageUrls} />
|
||||
)}
|
||||
{messages.length - 1 === index &&
|
||||
message.sender === "assistant" &&
|
||||
isAwaitingUserConfirmation && <ConfirmationButtons />}
|
||||
{shouldShowConfirmationButtons && <ConfirmationButtons />}
|
||||
</ChatMessage>
|
||||
);
|
||||
}),
|
||||
|
||||
@@ -43,7 +43,7 @@ export function AgentStatusBar() {
|
||||
|
||||
React.useEffect(() => {
|
||||
if (status === WsClientProviderStatus.DISCONNECTED) {
|
||||
setStatusMessage("Trying to reconnect...");
|
||||
setStatusMessage("Connecting...");
|
||||
} else {
|
||||
setStatusMessage(AGENT_STATUS_MAP[curAgentState].message);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useNavigate } from "react-router";
|
||||
import posthog from "posthog-js";
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
import { setInitialQuery } from "#/state/initial-query-slice";
|
||||
import { setInitialPrompt } from "#/state/initial-query-slice";
|
||||
import { RootState } from "#/store";
|
||||
import { useAuth } from "#/context/auth-context";
|
||||
|
||||
@@ -18,7 +18,7 @@ export const useCreateConversation = () => {
|
||||
);
|
||||
|
||||
return useMutation({
|
||||
mutationFn: (variables: { q?: string }) => {
|
||||
mutationFn: async (variables: { q?: string }) => {
|
||||
if (
|
||||
!variables.q?.trim() &&
|
||||
!selectedRepository &&
|
||||
@@ -28,10 +28,13 @@ export const useCreateConversation = () => {
|
||||
throw new Error("No query provided");
|
||||
}
|
||||
|
||||
if (variables.q) dispatch(setInitialQuery(variables.q));
|
||||
if (variables.q) dispatch(setInitialPrompt(variables.q));
|
||||
|
||||
return OpenHands.createConversation(
|
||||
gitHubToken || undefined,
|
||||
selectedRepository || undefined,
|
||||
variables.q,
|
||||
files,
|
||||
);
|
||||
},
|
||||
onSuccess: async ({ conversation_id: conversationId }, { q }) => {
|
||||
|
||||
@@ -6,7 +6,7 @@ import { useConfig } from "./use-config";
|
||||
import OpenHands from "#/api/open-hands";
|
||||
|
||||
export const useGitHubUser = () => {
|
||||
const { gitHubToken, setUserId } = useAuth();
|
||||
const { gitHubToken, setUserId, logout } = useAuth();
|
||||
const { data: config } = useConfig();
|
||||
|
||||
const user = useQuery({
|
||||
@@ -29,5 +29,11 @@ export const useGitHubUser = () => {
|
||||
}
|
||||
}, [user.data]);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (user.isError) {
|
||||
logout();
|
||||
}
|
||||
}, [user.isError]);
|
||||
|
||||
return user;
|
||||
};
|
||||
|
||||
@@ -141,7 +141,7 @@ export const handlers = [
|
||||
{ id: 2, full_name: "octocat/earth" },
|
||||
]),
|
||||
),
|
||||
http.get("https://api.github.com/user", () => {
|
||||
http.get("/api/github/user", () => {
|
||||
const user: GitHubUser = {
|
||||
id: 1,
|
||||
login: "octocat",
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import React from "react";
|
||||
import { useWSStatusChange } from "./hooks/use-ws-status-change";
|
||||
import { useHandleWSEvents } from "./hooks/use-handle-ws-events";
|
||||
import { useHandleRuntimeActive } from "./hooks/use-handle-runtime-active";
|
||||
|
||||
export function EventHandler({ children }: React.PropsWithChildren) {
|
||||
useWSStatusChange();
|
||||
useHandleWSEvents();
|
||||
useHandleRuntimeActive();
|
||||
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
import React from "react";
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import {
|
||||
useWsClient,
|
||||
WsClientProviderStatus,
|
||||
} from "#/context/ws-client-provider";
|
||||
import { createChatMessage } from "#/services/chat-service";
|
||||
import { setCurrentAgentState } from "#/state/agent-slice";
|
||||
import { addUserMessage } from "#/state/chat-slice";
|
||||
import { clearFiles, clearInitialQuery } from "#/state/initial-query-slice";
|
||||
import { RootState } from "#/store";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
|
||||
export const useWSStatusChange = () => {
|
||||
const { send, status } = useWsClient();
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
const dispatch = useDispatch();
|
||||
|
||||
const statusRef = React.useRef<WsClientProviderStatus | null>(null);
|
||||
|
||||
const { files, initialQuery } = useSelector(
|
||||
(state: RootState) => state.initialQuery,
|
||||
);
|
||||
|
||||
const sendInitialQuery = (query: string, base64Files: string[]) => {
|
||||
const timestamp = new Date().toISOString();
|
||||
send(createChatMessage(query, base64Files, timestamp));
|
||||
};
|
||||
|
||||
const dispatchInitialQuery = (query: string) => {
|
||||
sendInitialQuery(query, files);
|
||||
dispatch(clearFiles()); // reset selected files
|
||||
dispatch(clearInitialQuery()); // reset initial query
|
||||
};
|
||||
|
||||
const handleAgentInit = () => {
|
||||
if (initialQuery) {
|
||||
dispatchInitialQuery(initialQuery);
|
||||
}
|
||||
};
|
||||
React.useEffect(() => {
|
||||
if (curAgentState === AgentState.INIT) {
|
||||
handleAgentInit();
|
||||
}
|
||||
}, [curAgentState]);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (statusRef.current === status) {
|
||||
return; // This is a check because of strict mode - if the status did not change, don't do anything
|
||||
}
|
||||
statusRef.current = status;
|
||||
|
||||
if (status !== WsClientProviderStatus.DISCONNECTED && initialQuery) {
|
||||
dispatch(
|
||||
addUserMessage({
|
||||
content: initialQuery,
|
||||
imageUrls: files,
|
||||
timestamp: new Date().toISOString(),
|
||||
pending: true,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
if (status === WsClientProviderStatus.DISCONNECTED) {
|
||||
dispatch(setCurrentAgentState(AgentState.STOPPED));
|
||||
}
|
||||
}, [status]);
|
||||
};
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useDisclosure } from "@nextui-org/react";
|
||||
import React from "react";
|
||||
import { Outlet } from "react-router";
|
||||
import { useDispatch } from "react-redux";
|
||||
import { useDispatch, useSelector } from "react-redux";
|
||||
import { FaServer } from "react-icons/fa";
|
||||
import toast from "react-hot-toast";
|
||||
import { useTranslation } from "react-i18next";
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
useConversation,
|
||||
} from "#/context/conversation-context";
|
||||
import { Controls } from "#/components/features/controls/controls";
|
||||
import { clearMessages } from "#/state/chat-slice";
|
||||
import { clearMessages, addUserMessage } from "#/state/chat-slice";
|
||||
import { clearTerminal } from "#/state/command-slice";
|
||||
import { useEffectOnce } from "#/hooks/use-effect-once";
|
||||
import CodeIcon from "#/icons/code.svg?react";
|
||||
@@ -36,6 +36,8 @@ import { ServedAppLabel } from "#/components/layout/served-app-label";
|
||||
import { TerminalStatusLabel } from "#/components/features/terminal/terminal-status-label";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { MULTI_CONVERSATION_UI } from "#/utils/feature-flags";
|
||||
import { clearFiles, clearInitialPrompt } from "#/state/initial-query-slice";
|
||||
import { RootState } from "#/store";
|
||||
|
||||
function AppContent() {
|
||||
useConversationConfig();
|
||||
@@ -46,6 +48,9 @@ function AppContent() {
|
||||
const { data: conversation, isFetched } = useUserConversation(
|
||||
conversationId || null,
|
||||
);
|
||||
const { initialPrompt, files } = useSelector(
|
||||
(state: RootState) => state.initialQuery,
|
||||
);
|
||||
const dispatch = useDispatch();
|
||||
const endSession = useEndSession();
|
||||
|
||||
@@ -74,6 +79,18 @@ function AppContent() {
|
||||
dispatch(clearMessages());
|
||||
dispatch(clearTerminal());
|
||||
dispatch(clearJupyter());
|
||||
if (conversationId && (initialPrompt || files.length > 0)) {
|
||||
dispatch(
|
||||
addUserMessage({
|
||||
content: initialPrompt || "",
|
||||
imageUrls: files || [],
|
||||
timestamp: new Date().toISOString(),
|
||||
pending: true,
|
||||
}),
|
||||
);
|
||||
dispatch(clearInitialPrompt());
|
||||
dispatch(clearFiles());
|
||||
}
|
||||
}, [conversationId]);
|
||||
|
||||
useEffectOnce(() => {
|
||||
|
||||
@@ -2,14 +2,14 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit";
|
||||
|
||||
type SliceState = {
|
||||
files: string[]; // base64 encoded images
|
||||
initialQuery: string | null;
|
||||
initialPrompt: string | null;
|
||||
selectedRepository: string | null;
|
||||
importedProjectZip: string | null; // base64 encoded zip
|
||||
};
|
||||
|
||||
const initialState: SliceState = {
|
||||
files: [],
|
||||
initialQuery: null,
|
||||
initialPrompt: null,
|
||||
selectedRepository: null,
|
||||
importedProjectZip: null,
|
||||
};
|
||||
@@ -27,11 +27,11 @@ export const selectedFilesSlice = createSlice({
|
||||
clearFiles(state) {
|
||||
state.files = [];
|
||||
},
|
||||
setInitialQuery(state, action: PayloadAction<string>) {
|
||||
state.initialQuery = action.payload;
|
||||
setInitialPrompt(state, action: PayloadAction<string>) {
|
||||
state.initialPrompt = action.payload;
|
||||
},
|
||||
clearInitialQuery(state) {
|
||||
state.initialQuery = null;
|
||||
clearInitialPrompt(state) {
|
||||
state.initialPrompt = null;
|
||||
},
|
||||
setSelectedRepository(state, action: PayloadAction<string | null>) {
|
||||
state.selectedRepository = action.payload;
|
||||
@@ -49,8 +49,8 @@ export const {
|
||||
addFile,
|
||||
removeFile,
|
||||
clearFiles,
|
||||
setInitialQuery,
|
||||
clearInitialQuery,
|
||||
setInitialPrompt,
|
||||
clearInitialPrompt,
|
||||
setSelectedRepository,
|
||||
clearSelectedRepository,
|
||||
setImportedProjectZip,
|
||||
|
||||
+1
-1
@@ -20,7 +20,7 @@ The key classes in OpenHands are:
|
||||
* Sandbox: the part of the runtime responsible for running commands, e.g. inside of Docker
|
||||
* Server: brokers OpenHands sessions over HTTP, e.g. to drive the frontend
|
||||
* Session: holds a single EventStream, a single AgentController, and a single Runtime. Generally represents a single task (but potentially including several user prompts)
|
||||
* SessionManager: keeps a list of active sessions, and ensures requests are routed to the correct Session
|
||||
* ConversationManager: keeps a list of active sessions, and ensures requests are routed to the correct Session
|
||||
|
||||
## Control Flow
|
||||
Here's the basic loop (in pseudocode) that drives agents.
|
||||
|
||||
@@ -32,6 +32,7 @@ from openhands.events.tool import ToolCallMetadata
|
||||
_BASH_DESCRIPTION = """Execute a bash command in the terminal.
|
||||
* Long running commands: For commands that may run indefinitely, it should be run in the background and the output should be redirected to a file, e.g. command = `python3 app.py > server.log 2>&1 &`.
|
||||
* Interact with running process: If a bash command returns exit code `-1`, this means the process is not yet finished. By setting `is_input` to `true`, the assistant can interact with the running process and send empty `command` to retrieve any additional logs, or send additional text (set `command` to the text) to STDIN of the running process, or send command like `C-c` (Ctrl+C), `C-d` (Ctrl+D), `C-z` (Ctrl+Z) to interrupt the process.
|
||||
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.
|
||||
"""
|
||||
|
||||
CmdRunTool = ChatCompletionToolParam(
|
||||
@@ -44,7 +45,7 @@ CmdRunTool = ChatCompletionToolParam(
|
||||
'properties': {
|
||||
'command': {
|
||||
'type': 'string',
|
||||
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process.',
|
||||
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
|
||||
},
|
||||
'is_input': {
|
||||
'type': 'string',
|
||||
|
||||
@@ -501,10 +501,6 @@ class AgentController:
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
|
||||
if new_state == AgentState.INIT and self.state.resume_state:
|
||||
await self.set_agent_state_to(self.state.resume_state)
|
||||
self.state.resume_state = None
|
||||
|
||||
def get_agent_state(self) -> AgentState:
|
||||
"""Returns the current state of the agent.
|
||||
|
||||
|
||||
@@ -4,10 +4,6 @@ __all__ = ['ActionType']
|
||||
|
||||
|
||||
class ActionTypeSchema(BaseModel):
|
||||
INIT: str = Field(default='initialize')
|
||||
"""Initializes the agent. Only sent by client.
|
||||
"""
|
||||
|
||||
MESSAGE: str = Field(default='message')
|
||||
"""Represents a message.
|
||||
"""
|
||||
|
||||
@@ -6,10 +6,6 @@ class AgentState(str, Enum):
|
||||
"""The agent is loading.
|
||||
"""
|
||||
|
||||
INIT = 'init'
|
||||
"""The agent is initialized.
|
||||
"""
|
||||
|
||||
RUNNING = 'running'
|
||||
"""The agent is running.
|
||||
"""
|
||||
|
||||
@@ -18,3 +18,4 @@ class GithubIssue(BaseModel):
|
||||
review_threads: list[ReviewThread] | None = None
|
||||
thread_ids: list[str] | None = None
|
||||
head_branch: str | None = None
|
||||
base_branch: str | None = None
|
||||
|
||||
@@ -331,9 +331,10 @@ def main():
|
||||
if not token:
|
||||
raise ValueError('Github token is required.')
|
||||
|
||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||
llm_config = LLMConfig(
|
||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||
api_key=my_args.llm_api_key or os.environ['LLM_API_KEY'],
|
||||
api_key=str(api_key) if api_key else None,
|
||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||
)
|
||||
|
||||
|
||||
@@ -307,7 +307,6 @@ async def resolve_issue(
|
||||
repo_instruction: str | None,
|
||||
issue_number: int,
|
||||
comment_id: int | None,
|
||||
target_branch: str | None = None,
|
||||
reset_logger: bool = False,
|
||||
) -> None:
|
||||
"""Resolve a single github issue.
|
||||
@@ -326,7 +325,7 @@ async def resolve_issue(
|
||||
repo_instruction: Repository instruction to use.
|
||||
issue_number: Issue number to resolve.
|
||||
comment_id: Optional ID of a specific comment to focus on.
|
||||
target_branch: Optional target branch to create PR against (for PRs).
|
||||
|
||||
reset_logger: Whether to reset the logger for multiprocessing.
|
||||
"""
|
||||
issue_handler = issue_handler_factory(issue_type, owner, repo, token, llm_config)
|
||||
@@ -424,9 +423,9 @@ async def resolve_issue(
|
||||
try:
|
||||
# checkout to pr branch if needed
|
||||
if issue_type == 'pr':
|
||||
branch_to_use = target_branch if target_branch else issue.head_branch
|
||||
branch_to_use = issue.head_branch
|
||||
logger.info(
|
||||
f'Checking out to PR branch {target_branch} for issue {issue.number}'
|
||||
f'Checking out to PR branch {branch_to_use} for issue {issue.number}'
|
||||
)
|
||||
|
||||
if not branch_to_use:
|
||||
@@ -446,10 +445,6 @@ async def resolve_issue(
|
||||
cwd=repo_dir,
|
||||
)
|
||||
|
||||
# Update issue's base_branch if using custom target branch
|
||||
if target_branch:
|
||||
issue.base_branch = target_branch
|
||||
|
||||
base_commit = (
|
||||
subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=repo_dir)
|
||||
.decode('utf-8')
|
||||
@@ -572,12 +567,6 @@ def main():
|
||||
choices=['issue', 'pr'],
|
||||
help='Type of issue to resolve, either open issue or pr comments.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--target-branch',
|
||||
type=str,
|
||||
default=None,
|
||||
help="Target branch to pull and create PR against (for PRs). If not specified, uses the PR's base branch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--is-experimental',
|
||||
type=lambda x: x.lower() == 'true',
|
||||
@@ -601,9 +590,10 @@ def main():
|
||||
if not token:
|
||||
raise ValueError('Github token is required.')
|
||||
|
||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||
llm_config = LLMConfig(
|
||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||
api_key=my_args.llm_api_key or os.environ['LLM_API_KEY'],
|
||||
api_key=str(api_key) if api_key else None,
|
||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||
)
|
||||
|
||||
@@ -643,7 +633,6 @@ def main():
|
||||
repo_instruction=repo_instruction,
|
||||
issue_number=my_args.issue_number,
|
||||
comment_id=my_args.comment_id,
|
||||
target_branch=my_args.target_branch,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -719,9 +719,10 @@ def main():
|
||||
else os.getenv('GITHUB_USERNAME')
|
||||
)
|
||||
|
||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||
llm_config = LLMConfig(
|
||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||
api_key=my_args.llm_api_key or os.environ['LLM_API_KEY'],
|
||||
api_key=str(api_key) if api_key else None,
|
||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||
)
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ class ModalRuntime(ActionExecutionClient):
|
||||
|
||||
container_name_prefix = 'openhands-sandbox-'
|
||||
sandbox: modal.Sandbox | None
|
||||
sid: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -57,6 +58,7 @@ class ModalRuntime(ActionExecutionClient):
|
||||
|
||||
self.config = config
|
||||
self.sandbox = None
|
||||
self.sid = sid
|
||||
|
||||
self.modal_client = modal.Client.from_credentials(
|
||||
config.modal_api_token_id.get_secret_value(),
|
||||
@@ -75,6 +77,8 @@ class ModalRuntime(ActionExecutionClient):
|
||||
|
||||
# This value is arbitrary as it's private to the container
|
||||
self.container_port = 3000
|
||||
self._vscode_port = 4445
|
||||
self._vscode_url: str | None = None
|
||||
|
||||
self.status_callback = status_callback
|
||||
self.base_container_image_id = self.config.sandbox.base_container_image
|
||||
@@ -140,6 +144,7 @@ class ModalRuntime(ActionExecutionClient):
|
||||
|
||||
if not self.attach_to_existing:
|
||||
self.send_status_message(' ')
|
||||
self._runtime_initialized = True
|
||||
|
||||
def _get_action_execution_server_host(self):
|
||||
return self.api_url
|
||||
@@ -208,6 +213,7 @@ echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
|
||||
environment: dict[str, str | None] = {
|
||||
'port': str(self.container_port),
|
||||
'PYTHONUNBUFFERED': '1',
|
||||
'VSCODE_PORT': str(self._vscode_port),
|
||||
}
|
||||
if self.config.debug:
|
||||
environment['DEBUG'] = 'true'
|
||||
@@ -225,7 +231,7 @@ echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
|
||||
*sandbox_start_cmd,
|
||||
secrets=[env_secret],
|
||||
workdir='/openhands/code',
|
||||
encrypted_ports=[self.container_port],
|
||||
encrypted_ports=[self.container_port, self._vscode_port],
|
||||
image=self.image,
|
||||
app=self.app,
|
||||
client=self.modal_client,
|
||||
@@ -248,3 +254,27 @@ echo 'export INPUTRC=/etc/inputrc' >> /etc/bash.bashrc
|
||||
|
||||
if not self.attach_to_existing and self.sandbox:
|
||||
self.sandbox.terminate()
|
||||
|
||||
@property
|
||||
def vscode_url(self) -> str | None:
|
||||
if self._vscode_url is not None: # cached value
|
||||
self.log('debug', f'VSCode URL: {self._vscode_url}')
|
||||
return self._vscode_url
|
||||
token = super().get_vscode_token()
|
||||
if not token:
|
||||
self.log('error', 'VSCode token not found')
|
||||
return None
|
||||
if not self.sandbox:
|
||||
self.log('error', 'Sandbox not initialized')
|
||||
return None
|
||||
|
||||
tunnel = self.sandbox.tunnels()[self._vscode_port]
|
||||
tunnel_url = tunnel.url
|
||||
self._vscode_url = tunnel_url + f'/?tkn={token}&folder={self.config.workspace_mount_path_in_sandbox}'
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
f'VSCode URL: {self._vscode_url}',
|
||||
)
|
||||
|
||||
return self._vscode_url
|
||||
|
||||
@@ -306,6 +306,12 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
assert 'pod_status' in runtime_data
|
||||
pod_status = runtime_data['pod_status'].lower()
|
||||
self.log('debug', f'Pod status: {pod_status}')
|
||||
restart_count = runtime_data.get('restart_count', 0)
|
||||
if restart_count != 0:
|
||||
restart_reasons = runtime_data.get('restart_reasons')
|
||||
self.log(
|
||||
'debug', f'Pod restarts: {restart_count}, reasons: {restart_reasons}'
|
||||
)
|
||||
|
||||
# FIXME: We should fix it at the backend of /start endpoint, make sure
|
||||
# the pod is created before returning the response.
|
||||
|
||||
@@ -125,13 +125,13 @@ The `agent_session.py` file contains the `AgentSession` class, which manages the
|
||||
- Handling security analysis
|
||||
- Managing the event stream
|
||||
|
||||
### 3. session/manager.py
|
||||
### 3. session/conversation_manager/conversation_manager.py
|
||||
|
||||
The `manager.py` file defines the `SessionManager` class, which is responsible for managing multiple client sessions. Key features include:
|
||||
The `conversation_manager.py` file defines the `ConversationManager` class, which is responsible for managing multiple client conversations. Key features include:
|
||||
|
||||
- Adding and restarting sessions
|
||||
- Sending messages to specific sessions
|
||||
- Cleaning up inactive sessions
|
||||
- Adding and restarting conversations
|
||||
- Sending messages to specific conversations
|
||||
- Cleaning up inactive conversations
|
||||
|
||||
### 4. listen.py
|
||||
|
||||
@@ -148,7 +148,7 @@ The `listen.py` file is the main server file that sets up the FastAPI applicatio
|
||||
1. **Server Initialization**:
|
||||
- The FastAPI application is created and configured in `listen.py`.
|
||||
- CORS middleware and static file serving are set up.
|
||||
- The `SessionManager` is initialized.
|
||||
- The `ConversationManager` is initialized.
|
||||
|
||||
2. **Client Connection**:
|
||||
- When a client connects via WebSocket, a new `Session` is created or an existing one is restarted.
|
||||
@@ -173,7 +173,7 @@ The `listen.py` file is the main server file that sets up the FastAPI applicatio
|
||||
- Security-related API requests are forwarded to the security analyzer.
|
||||
|
||||
7. **Session Management**:
|
||||
- The `SessionManager` periodically cleans up inactive sessions.
|
||||
- The `ConversationManager` periodically cleans up inactive sessions.
|
||||
- It also handles sending messages to specific sessions when needed.
|
||||
|
||||
8. **API Endpoints**:
|
||||
|
||||
@@ -21,12 +21,12 @@ from openhands.server.routes.public import app as public_api_router
|
||||
from openhands.server.routes.security import app as security_api_router
|
||||
from openhands.server.routes.settings import app as settings_router
|
||||
from openhands.server.routes.trajectory import app as trajectory_router
|
||||
from openhands.server.shared import openhands_config, session_manager
|
||||
from openhands.server.shared import conversation_manager, openhands_config
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _lifespan(app: FastAPI):
|
||||
async with session_manager:
|
||||
async with conversation_manager:
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ class OpenhandsConfig(OpenhandsConfigInterface):
|
||||
conversation_store_class: str = (
|
||||
'openhands.storage.conversation.file_conversation_store.FileConversationStore'
|
||||
)
|
||||
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'
|
||||
|
||||
def verify_config(self):
|
||||
if self.config_cls:
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import socketio
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
|
||||
class ConversationManager(ABC):
    """Interface every OpenHands conversation manager must implement.

    Implementations may run standalone or clustered. The interface covers
    the conversation lifecycle: attaching and detaching, joining from a
    client connection, starting agent loops, routing client data and
    closing sessions. Managers are used as async context managers.
    """

    sio: socketio.AsyncServer
    config: AppConfig
    file_store: FileStore

    @abstractmethod
    async def __aenter__(self):
        """Set up the manager when the async context is entered."""

    @abstractmethod
    async def __aexit__(self, exc_type, exc_value, traceback):
        """Tear down the manager when the async context is exited."""

    @abstractmethod
    async def attach_to_conversation(self, sid: str) -> Conversation | None:
        """Attach to the conversation ``sid``, creating it when necessary.

        Returns ``None`` when the conversation cannot be attached.
        """

    @abstractmethod
    async def detach_from_conversation(self, conversation: Conversation):
        """Release a conversation obtained from ``attach_to_conversation``."""

    @abstractmethod
    async def join_conversation(
        self, sid: str, connection_id: str, settings: Settings, user_id: str | None
    ) -> EventStream | None:
        """Join connection ``connection_id`` to conversation ``sid``.

        Returns the conversation's event stream.
        """

    async def is_agent_loop_running(self, sid: str) -> bool:
        """Return whether an agent loop is running for session ``sid``."""
        running = await self.get_running_agent_loops(filter_to_sids={sid})
        return True if running else False

    @abstractmethod
    async def get_running_agent_loops(
        self, user_id: str | None = None, filter_to_sids: set[str] | None = None
    ) -> set[str]:
        """Return the ids of running agent loops.

        Results may be narrowed to one user and/or to a set of session ids.
        """

    @abstractmethod
    async def get_connections(
        self, user_id: str | None = None, filter_to_sids: set[str] | None = None
    ) -> dict[str, str]:
        """Return the known connections.

        Results may be narrowed to one user and/or to a set of session ids.
        """

    @abstractmethod
    async def maybe_start_agent_loop(
        self,
        sid: str,
        settings: Settings,
        user_id: str | None,
        initial_user_msg: MessageAction | None = None,
    ) -> EventStream:
        """Start an agent loop for ``sid`` unless one is already running.

        Returns the event stream of the agent loop.
        """

    @abstractmethod
    async def send_to_event_stream(self, connection_id: str, data: dict):
        """Forward ``data`` from a connection to its event stream."""

    @abstractmethod
    async def disconnect_from_session(self, connection_id: str):
        """Disconnect ``connection_id`` from its session."""

    @abstractmethod
    async def close_session(self, sid: str):
        """Close the session ``sid``."""

    @classmethod
    @abstractmethod
    def get_instance(
        cls,
        sio: socketio.AsyncServer,
        config: AppConfig,
        file_store: FileStore,
    ) -> ConversationManager:
        """Build a conversation manager from the server, config and file store."""
|
||||
@@ -0,0 +1,284 @@
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable
|
||||
|
||||
import socketio
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.stream import EventStream, session_exists
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.server.session.session import ROOM_KEY, Session
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.async_utils import wait_all
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
from .conversation_manager import ConversationManager
|
||||
|
||||
_CLEANUP_INTERVAL = 15
|
||||
MAX_RUNNING_CONVERSATIONS = 3
|
||||
|
||||
|
||||
@dataclass
class StandaloneConversationManager(ConversationManager):
    """Manages conversations in standalone mode (single server instance).

    All state lives in this process: running agent loops keyed by session
    id, the mapping from socket connection id to session id, and
    reference-counted ``Conversation`` objects handed out by
    ``attach_to_conversation``. A background task started in
    ``__aenter__`` periodically disconnects detached conversations and
    closes idle agent loops.
    """

    sio: socketio.AsyncServer
    config: AppConfig
    file_store: FileStore
    # Running agent loops, keyed by session id.
    _local_agent_loops_by_sid: dict[str, Session] = field(default_factory=dict)
    # Socket connection id -> session id it joined.
    _local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
    # Session id -> (conversation, number of current attachments).
    _active_conversations: dict[str, tuple[Conversation, int]] = field(
        default_factory=dict
    )
    # Session id -> (conversation, time.time() at which it was detached).
    _detached_conversations: dict[str, tuple[Conversation, float]] = field(
        default_factory=dict
    )
    # Guards _active_conversations and _detached_conversations.
    _conversations_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    _cleanup_task: asyncio.Task | None = None

    async def __aenter__(self):
        """Start the background cleanup task and return the manager."""
        self._cleanup_task = asyncio.create_task(self._cleanup_stale())
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Cancel the background cleanup task, if one is running."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            self._cleanup_task = None

    async def attach_to_conversation(self, sid: str) -> Conversation | None:
        """Return a connected conversation for ``sid``.

        An active conversation is reused (its attachment count is
        incremented), then a detached one, and only otherwise is a new
        conversation created and connected.

        Returns:
            The conversation, or ``None`` when no session exists for
            ``sid`` or connecting raises ``AgentRuntimeUnavailableError``.
        """
        start_time = time.time()
        if not await session_exists(sid, self.file_store):
            return None

        async with self._conversations_lock:
            # Check if we have an active conversation we can reuse
            if sid in self._active_conversations:
                conversation, count = self._active_conversations[sid]
                self._active_conversations[sid] = (conversation, count + 1)
                logger.info(f'Reusing active conversation {sid}')
                return conversation

            # Check if we have a detached conversation we can reuse
            if sid in self._detached_conversations:
                conversation, _ = self._detached_conversations.pop(sid)
                self._active_conversations[sid] = (conversation, 1)
                logger.info(f'Reusing detached conversation {sid}')
                return conversation

            # Create new conversation if none exists
            c = Conversation(sid, file_store=self.file_store, config=self.config)
            try:
                await c.connect()
            except AgentRuntimeUnavailableError as e:
                logger.error(f'Error connecting to conversation {c.sid}: {e}')
                await c.disconnect()
                return None
            end_time = time.time()
            logger.info(
                f'Conversation {c.sid} connected in {end_time - start_time} seconds'
            )
            self._active_conversations[sid] = (c, 1)
            return c

    async def join_conversation(
        self, sid: str, connection_id: str, settings: Settings, user_id: str | None
    ) -> EventStream | None:
        """Join a socket connection to session ``sid``.

        The connection enters the session's socket.io room and is recorded
        locally. If no agent loop is running for ``sid`` yet, one is
        started.

        Returns:
            The event stream of the session's agent loop.
        """
        logger.info(f'join_conversation:{sid}:{connection_id}')
        await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
        self._local_connection_id_to_session_id[connection_id] = sid
        event_stream = await self._get_event_stream(sid)
        if not event_stream:
            return await self.maybe_start_agent_loop(sid, settings, user_id)
        return event_stream

    async def detach_from_conversation(self, conversation: Conversation):
        """Drop one attachment to ``conversation``.

        When the last attachment is released the conversation moves to the
        detached pool, where the cleanup task disconnects it unless it is
        re-attached first.
        """
        sid = conversation.sid
        async with self._conversations_lock:
            if sid in self._active_conversations:
                conv, count = self._active_conversations[sid]
                if count > 1:
                    self._active_conversations[sid] = (conv, count - 1)
                    return
                else:
                    self._active_conversations.pop(sid)
                    self._detached_conversations[sid] = (conversation, time.time())

    async def _cleanup_stale(self):
        """Background loop that releases unused conversations and sessions.

        Each pass disconnects every detached conversation, then closes
        agent loops that have been inactive longer than
        ``config.sandbox.close_delay``, are not in the ``RUNNING`` state
        (or stateless) and have no connection attached. On cancellation
        everything is disconnected and closed before the loop returns.
        """
        while should_continue():
            try:
                async with self._conversations_lock:
                    # Snapshot the items so the dict can be mutated while looping.
                    items = list(self._detached_conversations.items())
                    for sid, (conversation, _) in items:
                        await conversation.disconnect()
                        self._detached_conversations.pop(sid, None)

                close_threshold = time.time() - self.config.sandbox.close_delay
                running_loops = list(self._local_agent_loops_by_sid.items())
                running_loops.sort(key=lambda item: item[1].last_active_ts)
                sid_to_close: list[str] = []
                for sid, session in running_loops:
                    state = session.agent_session.get_state()
                    if session.last_active_ts < close_threshold and state not in [
                        AgentState.RUNNING,
                        None,
                    ]:
                        sid_to_close.append(sid)

                # Never close a session that still has a client connected.
                connections = await self.get_connections(
                    filter_to_sids=set(sid_to_close)
                )
                connected_sids = set(connections.values())
                sid_to_close = [
                    sid for sid in sid_to_close if sid not in connected_sids
                ]
                await wait_all(self._close_session(sid) for sid in sid_to_close)
                await asyncio.sleep(_CLEANUP_INTERVAL)
            except asyncio.CancelledError:
                async with self._conversations_lock:
                    for conversation, _ in self._detached_conversations.values():
                        await conversation.disconnect()
                    self._detached_conversations.clear()
                # _close_session removes entries from the dict, so iterate
                # over a snapshot of the keys rather than the live dict.
                running_sids = list(self._local_agent_loops_by_sid)
                await wait_all(self._close_session(sid) for sid in running_sids)
                return
            except Exception as e:
                logger.warning(f'error_cleaning_stale: {str(e)}')
                await asyncio.sleep(_CLEANUP_INTERVAL)

    async def get_running_agent_loops(
        self, user_id: str | None = None, filter_to_sids: set[str] | None = None
    ) -> set[str]:
        """Get the running session ids.

        If a user is supplied, results are limited to session ids for that
        user. If ``filter_to_sids`` is supplied, results are limited to
        those ids of interest.
        """
        items: Iterable[tuple[str, Session]] = self._local_agent_loops_by_sid.items()
        if filter_to_sids is not None:
            items = (item for item in items if item[0] in filter_to_sids)
        if user_id:
            items = (item for item in items if item[1].user_id == user_id)
        return {sid for sid, _ in items}

    async def get_connections(
        self, user_id: str | None = None, filter_to_sids: set[str] | None = None
    ) -> dict[str, str]:
        """Return a ``connection_id -> sid`` mapping of local connections.

        Results may be limited to the sessions in ``filter_to_sids`` and/or
        to sessions whose agent loop belongs to ``user_id``.
        """
        connections = dict(self._local_connection_id_to_session_id)
        if filter_to_sids is not None:
            connections = {
                connection_id: sid
                for connection_id, sid in connections.items()
                if sid in filter_to_sids
            }
        if user_id:
            for connection_id, sid in list(connections.items()):
                session = self._local_agent_loops_by_sid.get(sid)
                if not session or session.user_id != user_id:
                    connections.pop(connection_id)
        return connections

    async def maybe_start_agent_loop(
        self,
        sid: str,
        settings: Settings,
        user_id: str | None,
        initial_user_msg: MessageAction | None = None,
    ) -> EventStream:
        """Start an agent loop for ``sid`` unless one is already running.

        When the user already has ``MAX_RUNNING_CONVERSATIONS`` loops
        running, the one that has been inactive the longest is closed
        first to make room.

        Returns:
            The event stream of the agent loop for ``sid``.

        Raises:
            RuntimeError: If no event stream is available afterwards.
        """
        logger.info(f'maybe_start_agent_loop:{sid}')
        if not await self.is_agent_loop_running(sid):
            logger.info(f'start_agent_loop:{sid}')

            running_sids = await self.get_running_agent_loops(user_id)
            if len(running_sids) >= MAX_RUNNING_CONVERSATIONS:
                logger.info(f'too_many_sessions_for:{user_id}')
                # The running ids come back as an unordered set, so choose
                # the victim explicitly: the least recently active loop.
                stalest_sid = min(
                    running_sids,
                    key=lambda running_sid: self._local_agent_loops_by_sid[
                        running_sid
                    ].last_active_ts,
                )
                await self.close_session(stalest_sid)

            session = Session(
                sid=sid,
                file_store=self.file_store,
                config=self.config,
                sio=self.sio,
                user_id=user_id,
            )
            self._local_agent_loops_by_sid[sid] = session
            # NOTE(review): the task reference is not retained; the asyncio
            # docs recommend keeping one so the task cannot be collected.
            asyncio.create_task(session.initialize_agent(settings, initial_user_msg))

        event_stream = await self._get_event_stream(sid)
        if not event_stream:
            logger.error(f'No event stream after starting agent loop: {sid}')
            raise RuntimeError(f'no_event_stream:{sid}')
        return event_stream

    async def _get_event_stream(self, sid: str) -> EventStream | None:
        """Return the event stream of the local agent loop for ``sid``, if any."""
        logger.info(f'_get_event_stream:{sid}')
        session = self._local_agent_loops_by_sid.get(sid)
        if session:
            logger.info(f'found_local_agent_loop:{sid}')
            return session.agent_session.event_stream
        return None

    async def send_to_event_stream(self, connection_id: str, data: dict):
        """Dispatch ``data`` to the session joined by ``connection_id``.

        Raises:
            RuntimeError: If the connection has not joined a session, or
                the session it joined has no running agent loop.
        """
        # If there is a local session running, send to that
        sid = self._local_connection_id_to_session_id.get(connection_id)
        if not sid:
            raise RuntimeError(f'no_connected_session:{connection_id}')

        session = self._local_agent_loops_by_sid.get(sid)
        if session:
            await session.dispatch(data)
            return

        raise RuntimeError(f'no_connected_session:{connection_id}:{sid}')

    async def disconnect_from_session(self, connection_id: str):
        """Forget ``connection_id``; the session itself keeps running."""
        sid = self._local_connection_id_to_session_id.pop(connection_id, None)
        logger.info(f'disconnect_from_session:{connection_id}:{sid}')
        if not sid:
            # This can occur if the init action was never run.
            logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
            return

    async def close_session(self, sid: str):
        """Close the agent loop for ``sid`` if one is running locally."""
        session = self._local_agent_loops_by_sid.get(sid)
        if session:
            await self._close_session(sid)

    async def _close_session(self, sid: str):
        """Remove every local record of ``sid`` and close its session."""
        logger.info(f'_close_session:{sid}')

        # Clear up local variables
        connection_ids_to_remove = [
            connection_id
            for connection_id, conn_sid in self._local_connection_id_to_session_id.items()
            if sid == conn_sid
        ]
        logger.info(f'removing connections: {connection_ids_to_remove}')
        for connection_id in connection_ids_to_remove:
            self._local_connection_id_to_session_id.pop(connection_id, None)

        session = self._local_agent_loops_by_sid.pop(sid, None)
        if not session:
            logger.warning(f'no_session_to_close:{sid}')
            return

        logger.info(f'closing_session:{session.sid}')
        await session.close()
        logger.info(f'closed_session:{session.sid}')

    @classmethod
    def get_instance(
        cls,
        sio: socketio.AsyncServer,
        config: AppConfig,
        file_store: FileStore,
    ) -> ConversationManager:
        """Build a standalone manager from the server, config and file store."""
        return StandaloneConversationManager(sio, config, file_store)
|
||||
@@ -5,7 +5,6 @@ from pydantic import SecretStr
|
||||
from socketio.exceptions import ConnectionRefusedError
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import (
|
||||
NullAction,
|
||||
)
|
||||
@@ -16,7 +15,7 @@ from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.events.stream import AsyncEventStreamWrapper
|
||||
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
|
||||
from openhands.server.shared import config, openhands_config, session_manager, sio
|
||||
from openhands.server.shared import config, conversation_manager, openhands_config, sio
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
@@ -70,7 +69,7 @@ async def connect(connection_id: str, environ, auth):
|
||||
'Settings not found', {'msg_id': 'CONFIGURATION$SETTINGS_NOT_FOUND'}
|
||||
)
|
||||
|
||||
event_stream = await session_manager.join_conversation(
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
conversation_id, connection_id, settings, user_id
|
||||
)
|
||||
|
||||
@@ -86,8 +85,6 @@ async def connect(connection_id: str, environ, auth):
|
||||
):
|
||||
continue
|
||||
elif isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state == AgentState.INIT:
|
||||
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
|
||||
agent_state_changed = event
|
||||
else:
|
||||
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
|
||||
@@ -97,10 +94,10 @@ async def connect(connection_id: str, environ, auth):
|
||||
|
||||
@sio.event
|
||||
async def oh_action(connection_id: str, data: dict):
|
||||
await session_manager.send_to_event_stream(connection_id, data)
|
||||
await conversation_manager.send_to_event_stream(connection_id, data)
|
||||
|
||||
|
||||
@sio.event
|
||||
async def disconnect(connection_id: str):
|
||||
logger.info(f'sio:disconnect:{connection_id}')
|
||||
await session_manager.disconnect_from_session(connection_id)
|
||||
await conversation_manager.disconnect_from_session(connection_id)
|
||||
|
||||
@@ -147,7 +147,7 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
Attach the user's session based on the provided authentication token.
|
||||
"""
|
||||
request.state.conversation = (
|
||||
await shared.session_manager.attach_to_conversation(request.state.sid)
|
||||
await shared.conversation_manager.attach_to_conversation(request.state.sid)
|
||||
)
|
||||
if not request.state.conversation:
|
||||
return JSONResponse(
|
||||
@@ -160,7 +160,7 @@ class AttachConversationMiddleware(SessionMiddlewareInterface):
|
||||
"""
|
||||
Detach the user's session.
|
||||
"""
|
||||
await shared.session_manager.detach_from_conversation(
|
||||
await shared.conversation_manager.detach_from_conversation(
|
||||
request.state.conversation
|
||||
)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import uvicorn
|
||||
from fastapi import FastAPI, WebSocket
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
app = FastAPI()
|
||||
@@ -11,10 +10,6 @@ app = FastAPI()
|
||||
@app.websocket('/ws')
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
# send message to mock connection
|
||||
await websocket.send_json(
|
||||
{'action': ActionType.INIT, 'message': 'Control loop started.'}
|
||||
)
|
||||
|
||||
try:
|
||||
while should_continue():
|
||||
|
||||
@@ -7,12 +7,13 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import get_user_id
|
||||
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.shared import config, session_manager
|
||||
from openhands.server.shared import config, conversation_manager
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.conversation_info import ConversationInfo
|
||||
from openhands.storage.data_models.conversation_info_result_set import (
|
||||
@@ -34,6 +35,7 @@ class InitSessionRequest(BaseModel):
|
||||
github_token: str | None = None
|
||||
selected_repository: str | None = None
|
||||
initial_user_msg: str | None = None
|
||||
image_urls: list[str] | None = None
|
||||
|
||||
|
||||
async def _create_new_conversation(
|
||||
@@ -41,6 +43,7 @@ async def _create_new_conversation(
|
||||
token: str | None,
|
||||
selected_repository: str | None,
|
||||
initial_user_msg: str | None,
|
||||
image_urls: list[str] | None,
|
||||
):
|
||||
logger.info('Loading settings')
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
@@ -94,8 +97,14 @@ async def _create_new_conversation(
|
||||
)
|
||||
|
||||
logger.info(f'Starting agent loop for conversation {conversation_id}')
|
||||
event_stream = await session_manager.maybe_start_agent_loop(
|
||||
conversation_id, conversation_init_data, user_id, initial_user_msg
|
||||
initial_message_action = None
|
||||
if initial_user_msg or image_urls:
|
||||
initial_message_action = MessageAction(
|
||||
content=initial_user_msg or '',
|
||||
image_urls=image_urls or [],
|
||||
)
|
||||
event_stream = await conversation_manager.maybe_start_agent_loop(
|
||||
conversation_id, conversation_init_data, user_id, initial_message_action
|
||||
)
|
||||
try:
|
||||
event_stream.subscribe(
|
||||
@@ -121,10 +130,16 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
github_token = getattr(request.state, 'github_token', '') or data.github_token
|
||||
selected_repository = data.selected_repository
|
||||
initial_user_msg = data.initial_user_msg
|
||||
image_urls = data.image_urls or []
|
||||
|
||||
try:
|
||||
# Create conversation with initial message
|
||||
conversation_id = await _create_new_conversation(
|
||||
user_id, github_token, selected_repository, initial_user_msg
|
||||
user_id,
|
||||
github_token,
|
||||
selected_repository,
|
||||
initial_user_msg,
|
||||
image_urls,
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
@@ -166,7 +181,7 @@ async def search_conversations(
|
||||
for conversation in conversation_metadata_result_set.results
|
||||
if hasattr(conversation, 'created_at')
|
||||
)
|
||||
running_conversations = await session_manager.get_running_agent_loops(
|
||||
running_conversations = await conversation_manager.get_running_agent_loops(
|
||||
get_user_id(request), set(conversation_ids)
|
||||
)
|
||||
result = ConversationInfoResultSet(
|
||||
@@ -191,7 +206,7 @@ async def get_conversation(
|
||||
)
|
||||
try:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
is_running = await session_manager.is_agent_loop_running(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
conversation_info = await _get_conversation_info(metadata, is_running)
|
||||
return conversation_info
|
||||
except FileNotFoundError:
|
||||
@@ -225,9 +240,9 @@ async def delete_conversation(
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
is_running = await session_manager.is_agent_loop_running(conversation_id)
|
||||
is_running = await conversation_manager.is_agent_loop_running(conversation_id)
|
||||
if is_running:
|
||||
await session_manager.close_session(conversation_id)
|
||||
await conversation_manager.close_session(conversation_id)
|
||||
runtime_cls = get_runtime_cls(config.runtime)
|
||||
await runtime_cls.delete(conversation_id)
|
||||
await conversation_store.delete_metadata(conversation_id)
|
||||
|
||||
@@ -8,19 +8,12 @@ interruptions are recoverable.
|
||||
There are 3 main server side event handlers:
|
||||
|
||||
* `connect` - Invoked when a new connection to the server is established (this may be via HTTP or WebSocket).
|
||||
* `oh_action` - Invoked when a connected client sends an event (Such as `INIT` or a prompt for the Agent) -
|
||||
* `oh_action` - Invoked when a connected client sends an event (such as a prompt for the Agent) -
|
||||
this is distinct from the `oh_event` sent from the server to the client.
|
||||
* `disconnect` - Invoked when a connected client disconnects from the server.
|
||||
|
||||
## Init
|
||||
Each connection has a unique id, and when initially established, is not associated with any session. An
|
||||
`INIT` event must be sent to the server in order to attach a connection to a session. The `INIT` event
|
||||
may optionally include a GitHub token and a token to connect to an existing session. (Which may be running
|
||||
locally or may need to be hydrated). If no token is received as part of the init event, it is assumed a
|
||||
new session should be started.
|
||||
|
||||
## Disconnect
|
||||
The [manager](manager.py) manages connections and sessions. Each session may have zero or more connections
|
||||
associated with it, managed by invocations of `INIT` and disconnect. When a session no longer has any
|
||||
associated with it. When a session no longer has any
|
||||
connections associated with it, after a set amount of time (determined by `config.sandbox.close_delay`),
|
||||
the session and runtime are passivated (so they will need to be rehydrated to continue).
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from openhands.server.session.manager import SessionManager
|
||||
from openhands.server.session.session import Session
|
||||
|
||||
__all__ = ['Session', 'SessionManager']
|
||||
__all__ = ['Session']
|
||||
|
||||
@@ -9,8 +9,7 @@ from openhands.core.config import AgentConfig, AppConfig, LLMConfig
|
||||
from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import ChangeAgentStateAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action import ChangeAgentStateAction, MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.microagent import BaseMicroAgent
|
||||
@@ -72,7 +71,7 @@ class AgentSession:
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
github_token: str | None = None,
|
||||
selected_repository: str | None = None,
|
||||
initial_user_msg: str | None = None,
|
||||
initial_message: MessageAction | None = None,
|
||||
):
|
||||
"""Starts the Agent session
|
||||
Parameters:
|
||||
@@ -111,15 +110,17 @@ class AgentSession:
|
||||
agent_to_llm_config=agent_to_llm_config,
|
||||
agent_configs=agent_configs,
|
||||
)
|
||||
self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
if initial_user_msg:
|
||||
if initial_message:
|
||||
self.event_stream.add_event(initial_message, EventSource.USER)
|
||||
self.event_stream.add_event(
|
||||
MessageAction(content=initial_user_msg), EventSource.USER
|
||||
ChangeAgentStateAction(AgentState.RUNNING), EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
else:
|
||||
self.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.AWAITING_USER_INPUT),
|
||||
EventSource.ENVIRONMENT,
|
||||
)
|
||||
|
||||
self._starting = False
|
||||
|
||||
async def close(self):
|
||||
|
||||
@@ -11,6 +11,7 @@ from openhands.core.config import AppConfig
|
||||
from openhands.core.exceptions import AgentRuntimeUnavailableError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.stream import EventStream, session_exists
|
||||
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
|
||||
from openhands.server.session.conversation import Conversation
|
||||
@@ -446,7 +447,7 @@ class SessionManager:
|
||||
sid: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
initial_user_msg: str | None = None,
|
||||
initial_message: MessageAction | None = None,
|
||||
) -> EventStream:
|
||||
logger.info(f'maybe_start_agent_loop:{sid}')
|
||||
session: Session | None = None
|
||||
@@ -469,7 +470,7 @@ class SessionManager:
|
||||
user_id=user_id,
|
||||
)
|
||||
self._local_agent_loops_by_sid[sid] = session
|
||||
asyncio.create_task(session.initialize_agent(settings, initial_user_msg))
|
||||
asyncio.create_task(session.initialize_agent(settings, initial_message))
|
||||
|
||||
event_stream = await self._get_event_stream(sid)
|
||||
if not event_stream:
|
||||
|
||||
@@ -74,7 +74,9 @@ class Session:
|
||||
self.is_alive = False
|
||||
await self.agent_session.close()
|
||||
|
||||
async def initialize_agent(self, settings: Settings, initial_user_msg: str | None):
|
||||
async def initialize_agent(
|
||||
self, settings: Settings, initial_message: MessageAction | None
|
||||
):
|
||||
self.agent_session.event_stream.add_event(
|
||||
AgentStateChangedObservation('', AgentState.LOADING),
|
||||
EventSource.ENVIRONMENT,
|
||||
@@ -122,7 +124,7 @@ class Session:
|
||||
agent_configs=self.config.get_agent_configs(),
|
||||
github_token=github_token,
|
||||
selected_repository=selected_repository,
|
||||
initial_user_msg=initial_user_msg,
|
||||
initial_message=initial_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating agent_session: {e}')
|
||||
|
||||
@@ -5,8 +5,11 @@ from dotenv import load_dotenv
|
||||
|
||||
from openhands.core.config import load_app_config
|
||||
from openhands.server.config.openhands_config import load_openhands_config
|
||||
from openhands.server.session import SessionManager
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
ConversationManager,
|
||||
)
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -27,4 +30,8 @@ sio = socketio.AsyncServer(
|
||||
async_mode='asgi', cors_allowed_origins='*', client_manager=client_manager
|
||||
)
|
||||
|
||||
session_manager = SessionManager(sio, config, file_store)
|
||||
ConversationManagerImpl = get_impl(
|
||||
ConversationManager, # type: ignore
|
||||
openhands_config.conversation_manager_class,
|
||||
)
|
||||
conversation_manager = ConversationManagerImpl.get_instance(sio, config, file_store)
|
||||
|
||||
@@ -40,6 +40,8 @@ class FileConversationStore(ConversationStore):
|
||||
|
||||
# Temp: force int to str to stop pydantic being, well... pedantic
|
||||
json_obj = json.loads(json_str)
|
||||
if 'created_at' not in json_obj:
|
||||
raise FileNotFoundError(path)
|
||||
if isinstance(json_obj.get('github_user_id'), int):
|
||||
json_obj['github_user_id'] = str(json_obj.get('github_user_id'))
|
||||
|
||||
|
||||
Generated
+32
-21
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
@@ -1108,6 +1108,20 @@ humanfriendly = ">=9.1"
|
||||
[package.extras]
|
||||
cron = ["capturer (>=2.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "colormath"
|
||||
version = "3.0.0"
|
||||
description = "Color math and conversion library."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "colormath-3.0.0.tar.gz", hash = "sha256:3d4605af344527da0e4f9f504fad7ddbebda35322c566a6c72e28edb1ff31217"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
networkx = ">=2.0"
|
||||
numpy = "*"
|
||||
|
||||
[[package]]
|
||||
name = "comm"
|
||||
version = "0.2.2"
|
||||
@@ -3944,19 +3958,19 @@ pydantic = ">=1.10"
|
||||
|
||||
[[package]]
|
||||
name = "llama-index"
|
||||
version = "0.12.12"
|
||||
version = "0.12.13"
|
||||
description = "Interface between LLMs and your data"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "llama_index-0.12.12-py3-none-any.whl", hash = "sha256:208f77dba5fd8268cacd3d56ec3ee33b0001d5b6ec623c5b91c755af7b08cfae"},
|
||||
{file = "llama_index-0.12.12.tar.gz", hash = "sha256:d4e475726e342b1178736ae3ed93336fe114605e86431b6dfcb454a9e1f26e72"},
|
||||
{file = "llama_index-0.12.13-py3-none-any.whl", hash = "sha256:0b285aa451ced6bd8da40df99068ac96badf8b5725c4edc29f2bce4da2ffd8bc"},
|
||||
{file = "llama_index-0.12.13.tar.gz", hash = "sha256:1e39a397dcc51dabe280c121fd8d5451a6a84595233a8b26caa54d9b7ecf9ffc"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
llama-index-agent-openai = ">=0.4.0,<0.5.0"
|
||||
llama-index-cli = ">=0.4.0,<0.5.0"
|
||||
llama-index-core = ">=0.12.12,<0.13.0"
|
||||
llama-index-core = ">=0.12.13,<0.13.0"
|
||||
llama-index-embeddings-openai = ">=0.3.0,<0.4.0"
|
||||
llama-index-indices-managed-llama-cloud = ">=0.4.0"
|
||||
llama-index-llms-openai = ">=0.3.0,<0.4.0"
|
||||
@@ -4001,13 +4015,13 @@ llama-index-llms-openai = ">=0.3.0,<0.4.0"
|
||||
|
||||
[[package]]
|
||||
name = "llama-index-core"
|
||||
version = "0.12.12"
|
||||
version = "0.12.13"
|
||||
description = "Interface between LLMs and your data"
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "llama_index_core-0.12.12-py3-none-any.whl", hash = "sha256:cea491e87f65e6b775b5aef95720de302b85af1bdc67d779c4b09170a30e5b98"},
|
||||
{file = "llama_index_core-0.12.12.tar.gz", hash = "sha256:068b755bbc681731336e822f5977d7608585e8f759c6293ebd812e2659316a37"},
|
||||
{file = "llama_index_core-0.12.13-py3-none-any.whl", hash = "sha256:9708bb594bbddffd6ff0767242e49d8978d1ba60a2e62e071d9d123ad2f17e6f"},
|
||||
{file = "llama_index_core-0.12.13.tar.gz", hash = "sha256:77af0161246ce1de38efc17cb6438dfff9e9558af00bcfac7dd4d0b7325efa4b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -5469,6 +5483,7 @@ description = "Nvidia JIT LTO Library"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"},
|
||||
]
|
||||
@@ -5581,15 +5596,17 @@ realtime = ["websockets (>=13,<15)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-aci"
|
||||
version = "0.1.8"
|
||||
version = "0.1.9"
|
||||
description = "An Agent-Computer Interface (ACI) designed for software development agents OpenHands."
|
||||
optional = false
|
||||
python-versions = "^3.12"
|
||||
files = []
|
||||
develop = false
|
||||
python-versions = "<4.0,>=3.12"
|
||||
files = [
|
||||
{file = "openhands_aci-0.1.9-py3-none-any.whl", hash = "sha256:62af189878db046aa98475a41fa01200efd5ddf1db8a435c38da3d4ad32cb11a"},
|
||||
{file = "openhands_aci-0.1.9.tar.gz", hash = "sha256:690d33d355a3e4111f52861dbb96ff766b5a268202324a87c94ba67b628a63b1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
diskcache = "^5.6.3"
|
||||
diskcache = ">=5.6.3,<6.0.0"
|
||||
flake8 = "*"
|
||||
gitpython = "*"
|
||||
grep-ast = "0.3.3"
|
||||
@@ -5599,13 +5616,7 @@ numpy = "*"
|
||||
pandas = "*"
|
||||
scipy = "*"
|
||||
tree-sitter = "0.21.3"
|
||||
whatthepatch = "^1.0.6"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/All-Hands-AI/openhands-aci.git"
|
||||
reference = "fix-find-show-only-hidden-subpaths"
|
||||
resolved_reference = "910e8c470aff0e496bf262bc673c7ee7b4531159"
|
||||
whatthepatch = ">=1.0.6,<2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "opentelemetry-api"
|
||||
@@ -10119,4 +10130,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "6b74056694bdc84a4583c2f93a5b218f15688827cb59e289eb83331045a1582e"
|
||||
content-hash = "fbca4b2ca0fe2d1d3cac46164c0c1eb9e468dc6f6bc7165e9a3d62ea9f25d801"
|
||||
|
||||
+3
-1
@@ -65,7 +65,7 @@ runloop-api-client = "0.13.0"
|
||||
libtmux = ">=0.37,<0.40"
|
||||
pygithub = "^2.5.0"
|
||||
joblib = "*"
|
||||
openhands-aci = "0.1.8"
|
||||
openhands-aci = "0.1.9"
|
||||
python-socketio = "^5.11.4"
|
||||
redis = "^5.2.0"
|
||||
sse-starlette = "^2.1.3"
|
||||
@@ -101,6 +101,7 @@ reportlab = "*"
|
||||
[tool.coverage.run]
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@@ -129,6 +130,7 @@ ignore = ["D1"]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
@@ -500,6 +500,111 @@ def test_send_pull_request_with_reviewer(
|
||||
assert result == 'https://github.com/test-owner/test-repo/pull/1'
|
||||
|
||||
|
||||
@patch('subprocess.run')
|
||||
@patch('requests.post')
|
||||
@patch('requests.get')
|
||||
def test_send_pull_request_target_branch_with_fork(
|
||||
mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir
|
||||
):
|
||||
"""Test that target_branch works correctly when using a fork."""
|
||||
repo_path = os.path.join(mock_output_dir, 'repo')
|
||||
fork_owner = 'fork-owner'
|
||||
target_branch = 'custom-target'
|
||||
|
||||
# Mock API responses
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=404), # Branch doesn't exist
|
||||
MagicMock(status_code=200), # Target branch exists
|
||||
]
|
||||
|
||||
mock_post.return_value.json.return_value = {
|
||||
'html_url': 'https://github.com/test-owner/test-repo/pull/1'
|
||||
}
|
||||
|
||||
# Mock subprocess.run calls
|
||||
mock_run.side_effect = [
|
||||
MagicMock(returncode=0), # git checkout -b
|
||||
MagicMock(returncode=0), # git push
|
||||
]
|
||||
|
||||
# Call the function with fork_owner and target_branch
|
||||
result = send_pull_request(
|
||||
github_issue=mock_github_issue,
|
||||
github_token='test-token',
|
||||
github_username='test-user',
|
||||
patch_dir=repo_path,
|
||||
pr_type='ready',
|
||||
fork_owner=fork_owner,
|
||||
target_branch=target_branch,
|
||||
)
|
||||
|
||||
# Assert API calls
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
# Verify target branch was checked in original repo, not fork
|
||||
target_branch_check = mock_get.call_args_list[1]
|
||||
assert target_branch_check[0][0] == f'https://api.github.com/repos/test-owner/test-repo/branches/{target_branch}'
|
||||
|
||||
# Check PR creation
|
||||
mock_post.assert_called_once()
|
||||
post_data = mock_post.call_args[1]['json']
|
||||
assert post_data['base'] == target_branch # PR should target the specified branch
|
||||
assert post_data['head'] == 'openhands-fix-issue-42' # Branch name should be standard
|
||||
|
||||
# Check that push was to fork
|
||||
push_call = mock_run.call_args_list[1]
|
||||
assert f'https://test-user:test-token@github.com/{fork_owner}/test-repo.git' in str(push_call)
|
||||
|
||||
|
||||
@patch('subprocess.run')
|
||||
@patch('requests.post')
|
||||
@patch('requests.get')
|
||||
def test_send_pull_request_target_branch_with_additional_message(
|
||||
mock_get, mock_post, mock_run, mock_github_issue, mock_output_dir
|
||||
):
|
||||
"""Test that target_branch works correctly with additional PR message."""
|
||||
repo_path = os.path.join(mock_output_dir, 'repo')
|
||||
target_branch = 'feature-branch'
|
||||
additional_message = 'Additional PR context'
|
||||
|
||||
# Mock API responses
|
||||
mock_get.side_effect = [
|
||||
MagicMock(status_code=404), # Branch doesn't exist
|
||||
MagicMock(status_code=200), # Target branch exists
|
||||
]
|
||||
|
||||
mock_post.return_value.json.return_value = {
|
||||
'html_url': 'https://github.com/test-owner/test-repo/pull/1'
|
||||
}
|
||||
|
||||
# Mock subprocess.run calls
|
||||
mock_run.side_effect = [
|
||||
MagicMock(returncode=0), # git checkout -b
|
||||
MagicMock(returncode=0), # git push
|
||||
]
|
||||
|
||||
# Call the function with target_branch and additional_message
|
||||
result = send_pull_request(
|
||||
github_issue=mock_github_issue,
|
||||
github_token='test-token',
|
||||
github_username='test-user',
|
||||
patch_dir=repo_path,
|
||||
pr_type='ready',
|
||||
target_branch=target_branch,
|
||||
additional_message=additional_message,
|
||||
)
|
||||
|
||||
# Assert API calls
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
# Check PR creation
|
||||
mock_post.assert_called_once()
|
||||
post_data = mock_post.call_args[1]['json']
|
||||
assert post_data['base'] == target_branch
|
||||
assert additional_message in post_data['body']
|
||||
assert 'This pull request fixes #42' in post_data['body']
|
||||
|
||||
|
||||
@patch('requests.get')
|
||||
def test_send_pull_request_invalid_target_branch(
|
||||
mock_get, mock_github_issue, mock_output_dir
|
||||
|
||||
@@ -40,7 +40,7 @@ def _patch_store():
|
||||
MagicMock(return_value=file_store),
|
||||
):
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.session_manager.file_store',
|
||||
'openhands.server.routes.manage_conversations.conversation_manager.file_store',
|
||||
file_store,
|
||||
):
|
||||
yield
|
||||
|
||||
@@ -3,12 +3,6 @@ from unittest.mock import patch
|
||||
from openhands.core.config import AppConfig
|
||||
|
||||
|
||||
# Mock the SessionManager to avoid asyncio issues
|
||||
class MockSessionManager:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
# Mock StaticFiles
|
||||
class MockStaticFiles:
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -17,7 +11,6 @@ class MockStaticFiles:
|
||||
|
||||
# Patch necessary components before importing from listen
|
||||
with (
|
||||
patch('openhands.server.session.SessionManager', MockSessionManager),
|
||||
patch('fastapi.staticfiles.StaticFiles', MockStaticFiles),
|
||||
):
|
||||
from openhands.server.file_config import (
|
||||
|
||||
@@ -1,297 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.server.session.manager import SessionManager
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetMessageMock:
|
||||
message: dict | None
|
||||
sleep_time: int = 0.01
|
||||
|
||||
async def get_message(self, **kwargs):
|
||||
await asyncio.sleep(self.sleep_time)
|
||||
return {'data': json.dumps(self.message)}
|
||||
|
||||
|
||||
def get_mock_sio(get_message: GetMessageMock | None = None):
|
||||
sio = MagicMock()
|
||||
sio.enter_room = AsyncMock()
|
||||
sio.manager.redis = MagicMock()
|
||||
sio.manager.redis.publish = AsyncMock()
|
||||
pubsub = AsyncMock()
|
||||
pubsub.get_message = (get_message or GetMessageMock(None)).get_message
|
||||
sio.manager.redis.pubsub.return_value = pubsub
|
||||
return sio
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_not_running_in_cluster():
|
||||
sio = get_mock_sio()
|
||||
id = uuid4()
|
||||
with (
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
|
||||
):
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
result = await session_manager._get_running_agent_loops_remotely(
|
||||
filter_to_sids={'non-existant-session'}
|
||||
)
|
||||
assert result == set()
|
||||
assert sio.manager.redis.publish.await_count == 1
|
||||
sio.manager.redis.publish.assert_called_once_with(
|
||||
'session_msg',
|
||||
'{"query_id": "'
|
||||
+ str(id)
|
||||
+ '", "message_type": "running_agent_loops_query", "filter_to_sids": ["non-existant-session"]}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_running_agent_loops_remotely():
|
||||
id = uuid4()
|
||||
sio = get_mock_sio(
|
||||
GetMessageMock(
|
||||
{
|
||||
'query_id': str(id),
|
||||
'sids': ['existing-session'],
|
||||
'message_type': 'running_agent_loops_response',
|
||||
}
|
||||
)
|
||||
)
|
||||
with (
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
|
||||
patch('openhands.server.session.manager.uuid4', MagicMock(return_value=id)),
|
||||
):
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
result = await session_manager._get_running_agent_loops_remotely(
|
||||
1, {'existing-session'}
|
||||
)
|
||||
assert result == {'existing-session'}
|
||||
assert sio.manager.redis.publish.await_count == 1
|
||||
sio.manager.redis.publish.assert_called_once_with(
|
||||
'session_msg',
|
||||
'{"query_id": "'
|
||||
+ str(id)
|
||||
+ '", "message_type": "running_agent_loops_query", "user_id": 1, "filter_to_sids": ["existing-session"]}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_new_local_session():
|
||||
session_instance = AsyncMock()
|
||||
session_instance.agent_session = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = set()
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager._redis_subscribe',
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager.get_running_agent_loops',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.maybe_start_agent_loop(
|
||||
'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_local_session():
|
||||
session_instance = AsyncMock()
|
||||
session_instance.agent_session = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = set()
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager._redis_subscribe',
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager.get_running_agent_loops',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.maybe_start_agent_loop(
|
||||
'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_cluster_session():
|
||||
session_instance = AsyncMock()
|
||||
session_instance.agent_session = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = {'new-session-id'}
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager._redis_subscribe',
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 0
|
||||
assert sio.enter_room.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_to_local_event_stream():
|
||||
session_instance = AsyncMock()
|
||||
session_instance.agent_session = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = set()
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager._redis_subscribe',
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager.get_running_agent_loops',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.maybe_start_agent_loop(
|
||||
'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'connection-id', ConversationInitData(), 1
|
||||
)
|
||||
await session_manager.send_to_event_stream(
|
||||
'connection-id', {'event_type': 'some_event'}
|
||||
)
|
||||
session_instance.dispatch.assert_called_once_with({'event_type': 'some_event'})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_to_cluster_event_stream():
|
||||
session_instance = AsyncMock()
|
||||
session_instance.agent_session = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_session.return_value = session_instance
|
||||
sio = get_mock_sio()
|
||||
get_running_agent_loops_mock = AsyncMock()
|
||||
get_running_agent_loops_mock.return_value = {'new-session-id'}
|
||||
with (
|
||||
patch('openhands.server.session.manager.Session', mock_session),
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager._redis_subscribe',
|
||||
AsyncMock(),
|
||||
),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager._get_running_agent_loops_remotely',
|
||||
get_running_agent_loops_mock,
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
await session_manager.join_conversation(
|
||||
'new-session-id', 'connection-id', ConversationInitData(), 1
|
||||
)
|
||||
await session_manager.send_to_event_stream(
|
||||
'connection-id', {'event_type': 'some_event'}
|
||||
)
|
||||
assert sio.manager.redis.publish.await_count == 1
|
||||
sio.manager.redis.publish.assert_called_once_with(
|
||||
'session_msg',
|
||||
'{"sid": "new-session-id", "message_type": "event", "data": {"event_type": "some_event"}}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_session_connections():
|
||||
sio = get_mock_sio()
|
||||
with (
|
||||
patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
|
||||
patch(
|
||||
'openhands.server.session.manager.SessionManager._redis_subscribe',
|
||||
AsyncMock(),
|
||||
),
|
||||
):
|
||||
async with SessionManager(
|
||||
sio, AppConfig(), InMemoryFileStore()
|
||||
) as session_manager:
|
||||
session_manager._local_connection_id_to_session_id.update(
|
||||
{
|
||||
'conn1': 'session1',
|
||||
'conn2': 'session1',
|
||||
'conn3': 'session2',
|
||||
'conn4': 'session2',
|
||||
}
|
||||
)
|
||||
|
||||
await session_manager._close_session('session1')
|
||||
|
||||
remaining_connections = session_manager._local_connection_id_to_session_id
|
||||
assert 'conn1' not in remaining_connections
|
||||
assert 'conn2' not in remaining_connections
|
||||
assert 'conn3' in remaining_connections
|
||||
assert 'conn4' in remaining_connections
|
||||
assert remaining_connections['conn3'] == 'session2'
|
||||
assert remaining_connections['conn4'] == 'session2'
|
||||
@@ -0,0 +1,161 @@
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config.app_config import AppConfig
|
||||
from openhands.server.conversation_manager.standalone_conversation_manager import (
|
||||
StandaloneConversationManager,
|
||||
)
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetMessageMock:
|
||||
message: dict | None
|
||||
sleep_time: int = 0.01
|
||||
|
||||
async def get_message(self, **kwargs):
|
||||
await asyncio.sleep(self.sleep_time)
|
||||
return {'data': json.dumps(self.message)}
|
||||
|
||||
|
||||
def get_mock_sio(get_message: GetMessageMock | None = None):
|
||||
sio = MagicMock()
|
||||
sio.enter_room = AsyncMock()
|
||||
sio.manager.redis = MagicMock()
|
||||
sio.manager.redis.publish = AsyncMock()
|
||||
pubsub = AsyncMock()
|
||||
pubsub.get_message = (get_message or GetMessageMock(None)).get_message
|
||||
sio.manager.redis.pubsub.return_value = pubsub
|
||||
return sio
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_init_new_local_session():
    """Start an agent loop, join once, and check the resulting call counts.

    With ``Session`` and ``get_running_agent_loops`` patched out, one
    ``maybe_start_agent_loop`` followed by one ``join_conversation`` must
    produce exactly one ``initialize_agent`` call and one awaited
    ``enter_room``.
    """
    fake_session = AsyncMock()
    fake_session.agent_session = MagicMock()
    session_factory = MagicMock(return_value=fake_session)
    mock_sio = get_mock_sio()
    # Report that no agent loops are currently running.
    running_loops = AsyncMock(return_value=set())

    with (
        patch(
            'openhands.server.conversation_manager.standalone_conversation_manager.Session',
            session_factory,
        ),
        patch(
            'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager.get_running_agent_loops',
            running_loops,
        ),
    ):
        manager = StandaloneConversationManager(
            mock_sio, AppConfig(), InMemoryFileStore()
        )
        async with manager as conversation_manager:
            await conversation_manager.maybe_start_agent_loop(
                'new-session-id', ConversationInitData(), 1
            )
            await conversation_manager.join_conversation(
                'new-session-id', 'new-session-id', ConversationInitData(), 1
            )

    assert fake_session.initialize_agent.call_count == 1
    assert mock_sio.enter_room.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_join_local_session():
    """Join the same conversation twice and check the resulting call counts.

    After one ``maybe_start_agent_loop`` and two ``join_conversation`` calls
    (all passing ``None`` as the trailing argument), the agent must have been
    initialized once while ``enter_room`` was awaited twice.
    """
    fake_session = AsyncMock()
    fake_session.agent_session = MagicMock()
    session_factory = MagicMock(return_value=fake_session)
    mock_sio = get_mock_sio()
    # Report that no agent loops are currently running.
    running_loops = AsyncMock(return_value=set())

    with (
        patch(
            'openhands.server.conversation_manager.standalone_conversation_manager.Session',
            session_factory,
        ),
        patch(
            'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager.get_running_agent_loops',
            running_loops,
        ),
    ):
        manager = StandaloneConversationManager(
            mock_sio, AppConfig(), InMemoryFileStore()
        )
        async with manager as conversation_manager:
            await conversation_manager.maybe_start_agent_loop(
                'new-session-id', ConversationInitData(), None
            )
            # Two joins in a row; each one builds its own init data.
            for _ in range(2):
                await conversation_manager.join_conversation(
                    'new-session-id', 'new-session-id', ConversationInitData(), None
                )

    assert fake_session.initialize_agent.call_count == 1
    assert mock_sio.enter_room.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_to_local_event_stream():
    """Check that an event sent for a joined connection reaches the session.

    After starting the loop and joining as ``'connection-id'``, calling
    ``send_to_event_stream`` for that connection must result in exactly one
    ``dispatch`` call on the patched session, carrying the same payload.
    """
    fake_session = AsyncMock()
    fake_session.agent_session = MagicMock()
    session_factory = MagicMock(return_value=fake_session)
    mock_sio = get_mock_sio()
    # Report that no agent loops are currently running.
    running_loops = AsyncMock(return_value=set())

    with (
        patch(
            'openhands.server.conversation_manager.standalone_conversation_manager.Session',
            session_factory,
        ),
        patch(
            'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager.get_running_agent_loops',
            running_loops,
        ),
    ):
        manager = StandaloneConversationManager(
            mock_sio, AppConfig(), InMemoryFileStore()
        )
        async with manager as conversation_manager:
            await conversation_manager.maybe_start_agent_loop(
                'new-session-id', ConversationInitData(), 1
            )
            await conversation_manager.join_conversation(
                'new-session-id', 'connection-id', ConversationInitData(), 1
            )
            await conversation_manager.send_to_event_stream(
                'connection-id', {'event_type': 'some_event'}
            )

    fake_session.dispatch.assert_called_once_with({'event_type': 'some_event'})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cleanup_session_connections():
    """Check that closing a session drops only that session's connections.

    Four connection ids are seeded across two session ids; after
    ``_close_session('session1')`` the two ``session1`` connections must be
    gone while both ``session2`` connections remain mapped to ``'session2'``.
    """
    mock_sio = get_mock_sio()
    manager = StandaloneConversationManager(
        mock_sio, AppConfig(), InMemoryFileStore()
    )
    async with manager as conversation_manager:
        seeded = {
            'conn1': 'session1',
            'conn2': 'session1',
            'conn3': 'session2',
            'conn4': 'session2',
        }
        conversation_manager._local_connection_id_to_session_id.update(seeded)

        await conversation_manager._close_session('session1')

        remaining = conversation_manager._local_connection_id_to_session_id
        for closed_conn in ('conn1', 'conn2'):
            assert closed_conn not in remaining
        for open_conn in ('conn3', 'conn4'):
            assert open_conn in remaining
        for open_conn in ('conn3', 'conn4'):
            assert remaining[open_conn] == 'session2'
|
||||
Reference in New Issue
Block a user