Files
InvokeAI/invokeai/app/services/invocation_stats/invocation_stats_common.py
psychedelicious b24e8dd829 feat(stats): refactor InvocationStatsService to output stats as dataclasses
This allows the stats to be written to disk as JSON and analyzed.

- Add dataclasses to hold stats.
- Move stats pretty-print logic to `__str__` of the new `InvocationStatsSummary` class.
- Add `get_stats` and `dump_stats` methods to `InvocationStatsServiceBase`.
- `InvocationStatsService` now throws if stats are requested for a session it doesn't know about. This avoids needing to do a lot of messy null checks.
- Update `DefaultInvocationProcessor` to use the new stats methods and suppress the new errors.
2024-02-01 08:50:56 +11:00

184 lines
6.6 KiB
Python

from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Any, Optional
class GESStatsNotFoundError(Exception):
    """Raised when execution stats are not found for a given Graph Execution State.

    "GES" abbreviates "Graph Execution State". The raise sites are not in this module;
    presumably the stats service raises this for a session it has no stats for. Confirm
    at the callers.
    """
@dataclass
class NodeExecutionStatsSummary:
    """The stats for a specific type of node.

    One instance aggregates every recorded execution of a single invocation type.
    """

    # The invocation type these stats were aggregated for.
    node_type: str
    # How many executions of this node type were recorded.
    num_calls: int
    # Sum of the run times of all recorded executions of this node type, in seconds.
    time_used_seconds: float
    # Largest peak-VRAM reading among those executions, in GB.
    peak_vram_gb: float
@dataclass
class ModelCacheStatsSummary:
    """The stats for the model cache.

    Sizes are in GB (per the `_gb` suffix); counts are plain integers.
    """

    # Reported as the cache "high water mark", shown alongside `cache_size_gb`.
    high_water_mark_gb: float
    # Configured size of the cache, in GB.
    cache_size_gb: float
    # Reported as "RAM used to load models".
    total_usage_gb: float
    cache_hits: int
    cache_misses: int
    models_cached: int
    # Reported as "Models cleared from cache".
    models_cleared: int
@dataclass
class GraphExecutionStatsSummary:
    """The stats for the graph execution state."""

    graph_execution_state_id: str
    # Sum of the individual node run times, in seconds. This is not the elapsed wall time.
    execution_time_seconds: float
    # `wall_time_seconds`, `ram_usage_gb` and `ram_change_gb` are derived from the node execution stats.
    # In some situations, there are no node stats, so these values are optional.
    # Elapsed seconds from the earliest node start to the latest node end.
    wall_time_seconds: Optional[float]
    # RAM (GB) recorded at the end of the last-finishing node.
    ram_usage_gb: Optional[float]
    # RAM (GB) at the end of the last node minus RAM at the start of the first node.
    ram_change_gb: Optional[float]
@dataclass
class InvocationStatsSummary:
    """
    The accumulated stats for a graph execution.
    Its `__str__` method returns a human-readable stats summary.
    """

    # VRAM in use, in GB. `None` means no reading is available, and the line is omitted from `__str__`.
    vram_usage_gb: Optional[float]
    graph_stats: GraphExecutionStatsSummary
    model_cache_stats: ModelCacheStatsSummary
    node_stats: list[NodeExecutionStatsSummary]

    def __str__(self) -> str:
        """Return a human-readable, multi-line summary of the stats.

        Optional values (graph wall time, RAM figures, VRAM usage) are only included when
        they are not `None`. Every line, including the last, ends with a newline.
        """
        graph = self.graph_stats
        cache = self.model_cache_stats

        lines = [
            f"Graph stats: {graph.graph_execution_state_id}",
            f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Used':>10}",
        ]
        for summary in self.node_stats:
            lines.append(
                f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.peak_vram_gb:>9.3f}G"
            )
        lines.append(f"TOTAL GRAPH EXECUTION TIME: {graph.execution_time_seconds:7.3f}s")
        # Wall time and RAM figures are derived from node stats, which may be absent.
        if graph.wall_time_seconds is not None:
            lines.append(f"TOTAL GRAPH WALL TIME: {graph.wall_time_seconds:7.3f}s")
        if graph.ram_usage_gb is not None and graph.ram_change_gb is not None:
            lines.append(f"RAM used by InvokeAI process: {graph.ram_usage_gb:4.2f}G ({graph.ram_change_gb:+5.3f}G)")
        lines.append(f"RAM used to load models: {cache.total_usage_gb:4.2f}G")
        # Compare against None (not truthiness) so a genuine 0.0 reading is still reported.
        if self.vram_usage_gb is not None:
            lines.append(f"VRAM in use: {self.vram_usage_gb:4.3f}G")
        lines.extend(
            [
                "RAM cache statistics:",
                f" Model cache hits: {cache.cache_hits}",
                f" Model cache misses: {cache.cache_misses}",
                f" Models cached: {cache.models_cached}",
                f" Models cleared from cache: {cache.models_cleared}",
                f" Cache high water mark: {cache.high_water_mark_gb:4.2f}/{cache.cache_size_gb:4.2f}G",
            ]
        )
        return "\n".join(lines) + "\n"

    def as_dict(self) -> dict[str, Any]:
        """Returns the stats as a dictionary, recursively converting nested dataclasses (via `asdict`)."""
        return asdict(self)
@dataclass
class NodeExecutionStats:
    """Timing and memory measurements recorded for a single execution of an invocation node."""

    invocation_type: str
    start_time: float  # Seconds since the epoch.
    end_time: float  # Seconds since the epoch.
    start_ram_gb: float  # GB
    end_ram_gb: float  # GB
    peak_vram_gb: float  # GB

    def total_time(self) -> float:
        """Return how long this execution took: `end_time` minus `start_time`."""
        elapsed = self.end_time - self.start_time
        return elapsed
class GraphExecutionStats:
    """Class for tracking execution stats of a graph."""

    def __init__(self) -> None:
        # Stats for each node execution, in the order they were added.
        self._node_stats_list: list[NodeExecutionStats] = []

    def add_node_execution_stats(self, node_stats: NodeExecutionStats) -> None:
        """Record the stats of a single node execution."""
        self._node_stats_list.append(node_stats)

    def get_total_run_time(self) -> float:
        """Get the total time spent executing nodes in the graph.

        This is the sum of every recorded node's `total_time()`, not the elapsed wall time.
        Returns 0.0 when no node stats have been recorded.
        """
        return sum((node_stats.total_time() for node_stats in self._node_stats_list), 0.0)

    def get_first_node_stats(self) -> NodeExecutionStats | None:
        """Get the stats of the first node in the graph (by start_time).

        Returns None when no node stats have been recorded. On ties, the earliest-added node wins.
        """
        # No assert here: callers (e.g. `get_graph_stats_summary`) handle the empty case.
        return min(self._node_stats_list, key=lambda node_stats: node_stats.start_time, default=None)

    def get_last_node_stats(self) -> NodeExecutionStats | None:
        """Get the stats of the last node in the graph (by end_time).

        Returns None when no node stats have been recorded. On ties, the earliest-added node wins.
        """
        return max(self._node_stats_list, key=lambda node_stats: node_stats.end_time, default=None)

    def get_graph_stats_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
        """Get a summary of the graph stats.

        Wall time and RAM figures are derived from the first and last nodes. They are left as
        None when no node stats have been recorded.
        """
        first_node = self.get_first_node_stats()
        last_node = self.get_last_node_stats()

        wall_time_seconds: Optional[float] = None
        ram_usage_gb: Optional[float] = None
        ram_change_gb: Optional[float] = None

        if first_node is not None and last_node is not None:
            wall_time_seconds = last_node.end_time - first_node.start_time
            ram_usage_gb = last_node.end_ram_gb
            ram_change_gb = last_node.end_ram_gb - first_node.start_ram_gb

        return GraphExecutionStatsSummary(
            graph_execution_state_id=graph_execution_state_id,
            execution_time_seconds=self.get_total_run_time(),
            wall_time_seconds=wall_time_seconds,
            ram_usage_gb=ram_usage_gb,
            ram_change_gb=ram_change_gb,
        )

    def get_node_stats_summaries(self) -> list[NodeExecutionStatsSummary]:
        """Get a summary of the node stats, one entry per invocation type.

        Entries are ordered by the first appearance of each type. Each entry carries the call
        count, the summed run time and the maximum peak VRAM for that type.
        """
        node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
        for node_stats in self._node_stats_list:
            node_stats_by_type[node_stats.invocation_type].append(node_stats)

        return [
            NodeExecutionStatsSummary(
                node_type=node_type,
                num_calls=len(type_stats),
                time_used_seconds=sum(n.total_time() for n in type_stats),
                peak_vram_gb=max(n.peak_vram_gb for n in type_stats),
            )
            for node_type, type_stats in node_stats_by_type.items()
        ]