Compare commits

...

1 Commit

Author SHA1 Message Date
Reinier van der Leer
8b0579a87c feat(benchmark): Add -P, --parallel-tasks option to allow running multiple tasks concurrently
* Add dependency `pytest-parallel` and indirect dependency `py` (pylib)
* Make `SingletonReportManager` thread safe
2024-01-29 11:33:42 +01:00
5 changed files with 79 additions and 23 deletions

View File

@@ -62,6 +62,9 @@ def start():
@click.option(
"-N", "--attempts", default=1, help="Number of times to run each challenge."
)
@click.option(
"-P", "--parallel-tasks", default=1, help="Number of challenges to run in parallel."
)
@click.option(
"-c",
"--category",
@@ -111,6 +114,7 @@ def run(
category: tuple[str],
skip_category: tuple[str],
attempts: int,
parallel_tasks: int,
cutoff: Optional[int] = None,
backend: Optional[bool] = False,
# agent_path: Optional[Path] = None,
@@ -158,6 +162,7 @@ def run(
categories=category,
skip_categories=skip_category,
attempts_per_challenge=attempts,
concurrent_tasks=parallel_tasks,
cutoff=cutoff,
)
@@ -177,6 +182,7 @@ def run(
categories=category,
skip_categories=skip_category,
attempts_per_challenge=attempts,
concurrent_tasks=parallel_tasks,
cutoff=cutoff,
)

View File

@@ -22,6 +22,7 @@ def run_benchmark(
categories: tuple[str] = tuple(),
skip_categories: tuple[str] = tuple(),
attempts_per_challenge: int = 1,
concurrent_tasks: int = 1,
mock: bool = False,
no_dep: bool = False,
no_cutoff: bool = False,
@@ -100,6 +101,9 @@ def run_benchmark(
if attempts_per_challenge > 1:
pytest_args.append(f"--attempts={attempts_per_challenge}")
if concurrent_tasks > 1:
pytest_args.append(f"--tests-per-worker={concurrent_tasks}")
if cutoff:
pytest_args.append(f"--cutoff={cutoff}")
logger.debug(f"Setting cuttoff override to {cutoff} seconds.")

View File

@@ -3,10 +3,11 @@ import json
import logging
import os
import sys
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from typing import Any, ClassVar
from agbenchmark.config import AgentBenchmarkConfig
from agbenchmark.reports.processing.graphs import save_single_radar_chart
@@ -20,39 +21,39 @@ logger = logging.getLogger(__name__)
class SingletonReportManager:
instance = None
_instance = None
_lock: ClassVar[threading.Lock] = threading.Lock()
INFO_MANAGER: "SessionReportManager"
REGRESSION_MANAGER: "RegressionTestsTracker"
SUCCESS_RATE_TRACKER: "SuccessRatesTracker"
def __new__(cls):
if not cls.instance:
cls.instance = super(SingletonReportManager, cls).__new__(cls)
with cls._lock:
if not cls._instance:
cls._instance = super(SingletonReportManager, cls).__new__(cls)
agent_benchmark_config = AgentBenchmarkConfig.load()
benchmark_start_time_dt = datetime.now(
timezone.utc
) # or any logic to fetch the datetime
agent_benchmark_config = AgentBenchmarkConfig.load()
benchmark_start_time_dt = datetime.now(timezone.utc)
# Make the Managers class attributes
cls.INFO_MANAGER = SessionReportManager(
agent_benchmark_config.get_report_dir(benchmark_start_time_dt)
/ "report.json",
benchmark_start_time_dt,
)
cls.REGRESSION_MANAGER = RegressionTestsTracker(
agent_benchmark_config.regression_tests_file
)
cls.SUCCESS_RATE_TRACKER = SuccessRatesTracker(
agent_benchmark_config.success_rate_file
)
# Make the Managers class attributes
cls.INFO_MANAGER = SessionReportManager(
agent_benchmark_config.get_report_dir(benchmark_start_time_dt)
/ "report.json",
benchmark_start_time_dt,
)
cls.REGRESSION_MANAGER = RegressionTestsTracker(
agent_benchmark_config.regression_tests_file
)
cls.SUCCESS_RATE_TRACKER = SuccessRatesTracker(
agent_benchmark_config.success_rate_file
)
return cls.instance
return cls._instance
@classmethod
def clear_instance(cls):
cls.instance = None
cls._instance = None
cls.INFO_MANAGER = None
cls.REGRESSION_MANAGER = None
cls.SUCCESS_RATE_TRACKER = None
@@ -131,6 +132,12 @@ class SessionReportManager(BaseReportManager):
self.save()
def get_test_report(self, test_name: str) -> Test | None:
if isinstance(self.tests, Report):
return self.tests.tests.get(test_name)
else:
return self.tests.get(test_name)
def finalize_session_report(self, config: AgentBenchmarkConfig) -> None:
command = " ".join(sys.argv)

39
benchmark/poetry.lock generated
View File

@@ -1946,6 +1946,17 @@ files = [
[package.extras]
tests = ["pytest"]
[[package]]
name = "py"
version = "1.11.0"
description = "library with cross-python path, ini-parsing, io, code, log facilities"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
files = [
{file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
]
[[package]]
name = "pyasn1"
version = "0.5.1"
@@ -2137,6 +2148,21 @@ pytest = ">=7.0.0"
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]]
name = "pytest-parallel"
version = "0.1.1"
description = "a pytest plugin for parallel and concurrent testing"
optional = false
python-versions = "*"
files = [
{file = "pytest-parallel-0.1.1.tar.gz", hash = "sha256:9aac3fc199a168c0a8559b60249d9eb254de7af58c12cee0310b54d4affdbfab"},
{file = "pytest_parallel-0.1.1-py3-none-any.whl", hash = "sha256:9e3703015b0eda52be9e07d2ba3498f09340a56d5c79a39b50f22fc5c38212fe"},
]
[package.dependencies]
pytest = ">=3.0.0"
tblib = "*"
[[package]]
name = "python-dateutil"
version = "2.8.2"
@@ -2431,6 +2457,17 @@ anyio = ">=3.4.0,<5"
[package.extras]
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"]
[[package]]
name = "tblib"
version = "3.0.0"
description = "Traceback serialization library."
optional = false
python-versions = ">=3.8"
files = [
{file = "tblib-3.0.0-py3-none-any.whl", hash = "sha256:80a6c77e59b55e83911e1e607c649836a69c103963c5f28a46cbeef44acf8129"},
{file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"},
]
[[package]]
name = "toml"
version = "0.10.2"
@@ -2760,4 +2797,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "e0d1f991958a5d630287c7bb668e7fdc6183630e06196cf6f507a086be10baec"
content-hash = "4a4e53f252c8996b172bbb35a730197c07c53d7b50bf1d21964d3b2237495066"

View File

@@ -25,7 +25,9 @@ networkx = "^3.1"
colorama = "^0.4.6"
pyvis = "^0.3.2"
selenium = "^4.11.2"
py = "^1.11.0" # needed for pytest-parallel
pytest-asyncio = "^0.21.1"
pytest-parallel = "^0.1.1"
uvicorn = "^0.23.2"
fastapi = "^0.99.0"
python-multipart = "^0.0.6"