mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 23:58:06 -05:00
refactor(benchmark): Simplify models in report_types.py
- Removed ForbidOptionalMeta and BaseModelBenchmark classes. - Changed model attributes to optional: `Metrics.difficulty`, `Metrics.success`, `Metrics.success_percentage`, `Metrics.run_time`, and `Test.reached_cutoff`. - Added validator to `Metrics` model to require `success` and `run_time` fields if `attempted=True`. - Added default values to all optional model fields. - Removed duplicate imports. - Added condition in process_report.py to prevent null lookups if `metrics.difficulty` is not set.
This commit is contained in:
@@ -46,7 +46,7 @@ def get_agent_category(report: Report) -> dict[str, Any]:
|
||||
):
|
||||
continue
|
||||
categories.setdefault(category, 0)
|
||||
if data.metrics.success:
|
||||
if data.metrics.success and data.metrics.difficulty:
|
||||
num_dif = STRING_DIFFICULTY_MAP[data.metrics.difficulty]
|
||||
if num_dif > categories[category]:
|
||||
categories[category] = num_dif
|
||||
|
||||
@@ -1,48 +1,38 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, constr, validator
|
||||
|
||||
datetime_format = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\+00:00$"
|
||||
from pydantic import BaseModel, constr
|
||||
|
||||
|
||||
class ForbidOptionalMeta(type(BaseModel)): # metaclass to forbid optional fields
|
||||
def __new__(cls, name: str, bases: tuple, dct: Dict[str, Any]) -> Any:
|
||||
for attr_name, attr_value in dct.items():
|
||||
if (
|
||||
getattr(attr_value, "__origin__", None) == Union
|
||||
and type(None) in attr_value.__args__
|
||||
):
|
||||
raise TypeError(
|
||||
f"Optional fields are forbidden, but found in {attr_name}"
|
||||
)
|
||||
|
||||
return super().__new__(cls, name, bases, dct)
|
||||
|
||||
|
||||
class BaseModelBenchmark(BaseModel, metaclass=ForbidOptionalMeta):
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
|
||||
class Metrics(BaseModelBenchmark):
|
||||
difficulty: str
|
||||
success: bool
|
||||
success_percentage: float = Field(..., alias="success_%")
|
||||
run_time: str
|
||||
fail_reason: str | None
|
||||
class Metrics(BaseModel):
|
||||
difficulty: str | None
|
||||
success: bool | None = None
|
||||
run_time: str | None = None
|
||||
fail_reason: str | None = None
|
||||
success_percentage: float | None = Field(default=None, alias="success_%")
|
||||
attempted: bool
|
||||
cost: float | None
|
||||
cost: float | None = None
|
||||
|
||||
@validator("attempted")
|
||||
def require_metrics_if_attempted(cls, v: bool, values: dict[str, Any]):
|
||||
required_fields_if_attempted = ["success", "run_time"]
|
||||
if v:
|
||||
for f in required_fields_if_attempted:
|
||||
assert (
|
||||
values.get(f) is not None
|
||||
), f"'{f}' must be defined if attempted is True"
|
||||
return v
|
||||
|
||||
|
||||
class MetricsOverall(BaseModelBenchmark):
|
||||
class MetricsOverall(BaseModel):
|
||||
run_time: str
|
||||
highest_difficulty: str
|
||||
percentage: float | None
|
||||
total_cost: float | None
|
||||
percentage: float | None = None
|
||||
total_cost: float | None = None
|
||||
|
||||
|
||||
class Test(BaseModelBenchmark):
|
||||
class Test(BaseModel):
|
||||
data_path: str
|
||||
is_regression: bool
|
||||
answer: str
|
||||
@@ -50,19 +40,19 @@ class Test(BaseModelBenchmark):
|
||||
metrics: Metrics
|
||||
category: List[str]
|
||||
task: str
|
||||
reached_cutoff: bool
|
||||
metadata: Any
|
||||
reached_cutoff: bool | None = None # None if in progress
|
||||
metadata: dict[str, Any] | None = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ReportBase(BaseModelBenchmark):
|
||||
class ReportBase(BaseModel):
|
||||
command: str
|
||||
completion_time: str | None
|
||||
completion_time: str | None = None
|
||||
benchmark_start_time: constr(regex=datetime_format)
|
||||
metrics: MetricsOverall
|
||||
config: Dict[str, str | dict[str, str]]
|
||||
agent_git_commit_sha: str | None
|
||||
benchmark_git_commit_sha: str | None
|
||||
repo_url: str | None
|
||||
agent_git_commit_sha: str | None = None
|
||||
benchmark_git_commit_sha: str | None = None
|
||||
repo_url: str | None = None
|
||||
|
||||
|
||||
class Report(ReportBase):
|
||||
|
||||
Reference in New Issue
Block a user