refactor(benchmark): Simplify models in report_types.py

- Removed ForbidOptionalMeta and BaseModelBenchmark classes.
- Changed model attributes to optional: `Metrics.difficulty`, `Metrics.success`, `Metrics.success_percentage`, `Metrics.run_time`, and `Test.reached_cutoff`.
- Added validator to `Metrics` model to require `success` and `run_time` fields if `attempted=True`.
- Added default values to all optional model fields.
- Removed duplicate imports.
- Added condition in process_report.py to prevent null lookups if `metrics.difficulty` is not set.
This commit is contained in:
Reinier van der Leer
2024-01-09 15:16:43 +01:00
parent 25cc6ad6ae
commit 370d6dbf5d
2 changed files with 31 additions and 41 deletions

View File

@@ -46,7 +46,7 @@ def get_agent_category(report: Report) -> dict[str, Any]:
):
continue
categories.setdefault(category, 0)
if data.metrics.success:
if data.metrics.success and data.metrics.difficulty:
num_dif = STRING_DIFFICULTY_MAP[data.metrics.difficulty]
if num_dif > categories[category]:
categories[category] = num_dif

View File

@@ -1,48 +1,38 @@
from typing import Any, Dict, List, Union
from typing import Any, Dict, List
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, constr, validator
datetime_format = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00$"
from pydantic import BaseModel, constr
class ForbidOptionalMeta(type(BaseModel)): # metaclass to forbid optional fields
def __new__(cls, name: str, bases: tuple, dct: Dict[str, Any]) -> Any:
for attr_name, attr_value in dct.items():
if (
getattr(attr_value, "__origin__", None) == Union
and type(None) in attr_value.__args__
):
raise TypeError(
f"Optional fields are forbidden, but found in {attr_name}"
)
return super().__new__(cls, name, bases, dct)
class BaseModelBenchmark(BaseModel, metaclass=ForbidOptionalMeta):
class Config:
extra = "forbid"
class Metrics(BaseModelBenchmark):
difficulty: str
success: bool
success_percentage: float = Field(..., alias="success_%")
run_time: str
fail_reason: str | None
class Metrics(BaseModel):
difficulty: str | None
success: bool | None = None
run_time: str | None = None
fail_reason: str | None = None
success_percentage: float | None = Field(default=None, alias="success_%")
attempted: bool
cost: float | None
cost: float | None = None
@validator("attempted")
def require_metrics_if_attempted(cls, v: bool, values: dict[str, Any]):
required_fields_if_attempted = ["success", "run_time"]
if v:
for f in required_fields_if_attempted:
assert (
values.get(f) is not None
), f"'{f}' must be defined if attempted is True"
return v
class MetricsOverall(BaseModelBenchmark):
class MetricsOverall(BaseModel):
run_time: str
highest_difficulty: str
percentage: float | None
total_cost: float | None
percentage: float | None = None
total_cost: float | None = None
class Test(BaseModelBenchmark):
class Test(BaseModel):
data_path: str
is_regression: bool
answer: str
@@ -50,19 +40,19 @@ class Test(BaseModelBenchmark):
metrics: Metrics
category: List[str]
task: str
reached_cutoff: bool
metadata: Any
reached_cutoff: bool | None = None # None if in progress
metadata: dict[str, Any] | None = Field(default_factory=dict)
class ReportBase(BaseModelBenchmark):
class ReportBase(BaseModel):
command: str
completion_time: str | None
completion_time: str | None = None
benchmark_start_time: constr(regex=datetime_format)
metrics: MetricsOverall
config: Dict[str, str | dict[str, str]]
agent_git_commit_sha: str | None
benchmark_git_commit_sha: str | None
repo_url: str | None
agent_git_commit_sha: str | None = None
benchmark_git_commit_sha: str | None = None
repo_url: str | None = None
class Report(ReportBase):