diff --git a/benchmark/agbenchmark/reports/processing/process_report.py b/benchmark/agbenchmark/reports/processing/process_report.py index 1f73ed3c0b..57a2ee4fb2 100644 --- a/benchmark/agbenchmark/reports/processing/process_report.py +++ b/benchmark/agbenchmark/reports/processing/process_report.py @@ -46,7 +46,7 @@ def get_agent_category(report: Report) -> dict[str, Any]: ): continue categories.setdefault(category, 0) - if data.metrics.success: + if data.metrics.success and data.metrics.difficulty: num_dif = STRING_DIFFICULTY_MAP[data.metrics.difficulty] if num_dif > categories[category]: categories[category] = num_dif diff --git a/benchmark/agbenchmark/reports/processing/report_types.py b/benchmark/agbenchmark/reports/processing/report_types.py index 3ba9e6c6bf..e462ce2811 100644 --- a/benchmark/agbenchmark/reports/processing/report_types.py +++ b/benchmark/agbenchmark/reports/processing/report_types.py @@ -1,48 +1,38 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, constr, validator datetime_format = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00$" -from pydantic import BaseModel, constr -class ForbidOptionalMeta(type(BaseModel)): # metaclass to forbid optional fields - def __new__(cls, name: str, bases: tuple, dct: Dict[str, Any]) -> Any: - for attr_name, attr_value in dct.items(): - if ( - getattr(attr_value, "__origin__", None) == Union - and type(None) in attr_value.__args__ - ): - raise TypeError( - f"Optional fields are forbidden, but found in {attr_name}" - ) - - return super().__new__(cls, name, bases, dct) - - -class BaseModelBenchmark(BaseModel, metaclass=ForbidOptionalMeta): - class Config: - extra = "forbid" - - -class Metrics(BaseModelBenchmark): - difficulty: str - success: bool - success_percentage: float = Field(..., alias="success_%") - run_time: str - fail_reason: str | None +class Metrics(BaseModel): + difficulty: str | None + success: bool | None = None + run_time: str | None = None + fail_reason: str | None = None + success_percentage: float | None = Field(default=None, alias="success_%") attempted: bool - cost: float | None + cost: float | None = None + + @validator("attempted") + def require_metrics_if_attempted(cls, v: bool, values: dict[str, Any]): + required_fields_if_attempted = ["success", "run_time"] + if v: + for f in required_fields_if_attempted: + assert ( + values.get(f) is not None + ), f"'{f}' must be defined if attempted is True" + return v -class MetricsOverall(BaseModelBenchmark): +class MetricsOverall(BaseModel): run_time: str highest_difficulty: str - percentage: float | None - total_cost: float | None + percentage: float | None = None + total_cost: float | None = None -class Test(BaseModelBenchmark): +class Test(BaseModel): data_path: str is_regression: bool answer: str @@ -50,19 +40,19 @@ class Test(BaseModelBenchmark): metrics: Metrics category: List[str] task: str - reached_cutoff: bool - metadata: Any + reached_cutoff: bool | None = None # None if in progress + metadata: dict[str, Any] | None = Field(default_factory=dict) -class ReportBase(BaseModelBenchmark): +class ReportBase(BaseModel): command: str - completion_time: str | None + completion_time: str | None = None benchmark_start_time: constr(regex=datetime_format) metrics: MetricsOverall config: Dict[str, str | dict[str, str]] - agent_git_commit_sha: str | None - benchmark_git_commit_sha: str | None - repo_url: str | None + agent_git_commit_sha: str | None = None + benchmark_git_commit_sha: str | None = None + repo_url: str | None = None class Report(ReportBase):