Compare commits

1 Commit

Author           SHA1        Message                                          Date
psychedelicious  317cf01db4  experiment: memory tracker via context manager   2025-05-13 15:49:04 +10:00
2 changed files with 127 additions and 36 deletions

File 1 of 2

@@ -2,9 +2,12 @@
from __future__ import annotations
import contextlib
import inspect
import os
import re
import sys
import time
import warnings
from abc import ABC, abstractmethod
from enum import Enum
@@ -24,7 +27,9 @@ from typing import (
Union,
)
import psutil
import semver
import torch
from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, create_model
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
@@ -140,6 +145,86 @@ class BaseInvocationOutput(BaseModel):
)
+try:
+    import resource  # Unix only; absent on Windows
+except ImportError:
+    resource = None
+
+
+class MemoryTracker(contextlib.AbstractContextManager["MemoryTracker"]):
+    """
+    Context manager that records start/end RSS and peak RSS (window-specific, via baseline + delta) and, if CUDA
+    is available, start/end/peak VRAM (via torch.cuda).
+    """
+
+    def __init__(self):
+        self.torch_device = torch.device("cuda") if torch.cuda.is_available() else None
+
+    def __enter__(self):
+        self._proc = psutil.Process(os.getpid())
+        self.start_time = time.time()
+        self.start_rss: int = self._proc.memory_info().rss
+        if sys.platform == "win32":
+            self._base_peak = self._proc.memory_info().peak_wset
+        elif resource:
+            ru = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+            # On Linux, ru_maxrss is reported in kilobytes. On macOS, it is reported in bytes. We want bytes.
+            self._base_peak = ru if sys.platform == "darwin" else ru * 1024
+        else:
+            self._base_peak = self.start_rss
+        if self.torch_device:
+            torch.cuda.reset_peak_memory_stats(self.torch_device)
+            self.start_vram = torch.cuda.memory_allocated(self.torch_device)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # end RSS
+        self.end_time = time.time()
+        self.end_rss: int = self._proc.memory_info().rss
+        # window peak RSS = max(0, end_peak - baseline_peak)
+        if sys.platform == "win32":
+            end_peak = self._proc.memory_info().peak_wset
+        elif resource:
+            ru = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+            end_peak = ru if sys.platform == "darwin" else ru * 1024
+        else:
+            end_peak = self.end_rss
+        self.peak_rss = max(0, end_peak - self._base_peak)
+        # end & peak VRAM via torch
+        if self.torch_device:
+            self.end_vram = torch.cuda.memory_allocated(self.torch_device)
+            peak_alloc = torch.cuda.max_memory_allocated(self.torch_device)
+            self.peak_vram = max(0, peak_alloc - self.start_vram)
+        return False  # don't suppress exceptions
+
+    @property
+    def stats(self):
+        out = {
+            "start_time": self.start_time,
+            "end_time": self.end_time,
+            "elapsed": self.end_time - self.start_time,
+            "start_rss": self.start_rss,
+            "end_rss": self.end_rss,
+            "peak_rss": self.peak_rss,
+        }
+        if self.torch_device:
+            out.update(
+                {
+                    "start_vram": self.start_vram,
+                    "end_vram": self.end_vram,
+                    "peak_vram": self.peak_vram,
+                }
+            )
+        return out
class RequiredConnectionException(Exception):
    """Raised when a field that requires a connection did not receive a value."""
@@ -192,49 +277,55 @@ class BaseInvocation(ABC, BaseModel):
         """Invoke with provided context and return outputs."""
         pass

-    def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput:
+    def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> tuple[BaseInvocationOutput, MemoryTracker]:
         """
         Internal invoke method, calls `invoke()` after some prep.
         Handles optional fields that are required to call `invoke()` and invocation cache.
         """
-        for field_name, field in self.model_fields.items():
-            if not field.json_schema_extra or callable(field.json_schema_extra):
-                # something has gone terribly awry, we should always have this and it should be a dict
-                continue
-
-            # Here we handle the case where the field is optional in the pydantic class, but required
-            # in the `invoke()` method.
-
-            orig_default = field.json_schema_extra.get("orig_default", PydanticUndefined)
-            orig_required = field.json_schema_extra.get("orig_required", True)
-            input_ = field.json_schema_extra.get("input", None)
-            if orig_default is not PydanticUndefined and not hasattr(self, field_name):
-                setattr(self, field_name, orig_default)
-            if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
-                if input_ == Input.Connection:
-                    raise RequiredConnectionException(self.model_fields["type"].default, field_name)
-                elif input_ == Input.Any:
-                    raise MissingInputException(self.model_fields["type"].default, field_name)
-
-        # skip node cache codepath if it's disabled
-        if services.configuration.node_cache_size == 0:
-            return self.invoke(context)
-
-        output: BaseInvocationOutput
-        if self.use_cache:
-            key = services.invocation_cache.create_key(self)
-            cached_value = services.invocation_cache.get(key)
-            if cached_value is None:
-                services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
-                output = self.invoke(context)
-                services.invocation_cache.save(key, output)
-                return output
-            else:
-                services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
-                return cached_value
-        else:
-            services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
-            return self.invoke(context)
+        with MemoryTracker() as memory_tracker:
+            for field_name, field in self.model_fields.items():
+                if not field.json_schema_extra or callable(field.json_schema_extra):
+                    # something has gone terribly awry, we should always have this and it should be a dict
+                    continue
+
+                # Here we handle the case where the field is optional in the pydantic class, but required
+                # in the `invoke()` method.
+
+                orig_default = field.json_schema_extra.get("orig_default", PydanticUndefined)
+                orig_required = field.json_schema_extra.get("orig_required", True)
+                input_ = field.json_schema_extra.get("input", None)
+                if orig_default is not PydanticUndefined and not hasattr(self, field_name):
+                    setattr(self, field_name, orig_default)
+                if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
+                    if input_ == Input.Connection:
+                        raise RequiredConnectionException(self.model_fields["type"].default, field_name)
+                    elif input_ == Input.Any:
+                        raise MissingInputException(self.model_fields["type"].default, field_name)
+
+            # skip node cache codepath if it's disabled
+            if services.configuration.node_cache_size == 0:
+                output = self.invoke(context)
+            elif self.use_cache:
+                key = services.invocation_cache.create_key(self)
+                cached_value = services.invocation_cache.get(key)
+                if cached_value is None:
+                    services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
+                    output = self.invoke(context)
+                    services.invocation_cache.save(key, output)
+                else:
+                    services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
+                    output = cached_value
+            else:
+                services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
+                output = self.invoke(context)
+
+        services.logger.info(memory_tracker.stats)
+        return (output, memory_tracker)

    id: str = Field(
        default_factory=uuid_string,

File 2 of 2

@@ -126,7 +126,7 @@ class DefaultSessionRunner(SessionRunnerBase):
            )

            # Invoke the node
-            output = invocation.invoke_internal(context=context, services=self._services)
+            (output, memory_tracker) = invocation.invoke_internal(context=context, services=self._services)

            # Save output and history
            queue_item.session.complete(invocation.id, output)
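A sketch of how the runner might surface the returned tracker downstream; this logging step is illustrative, not part of the commit (the names follow the runner code above, the log format is an assumption):

    # Hypothetical consumer of the returned tracker in the session runner.
    stats = memory_tracker.stats
    self._services.logger.info(
        f"node {invocation.id}: elapsed={stats['elapsed']:.2f}s "
        f"peak_rss={stats['peak_rss'] / 1024**2:.1f} MiB"
    )

Returning the tracker itself, rather than just its stats dict, leaves the caller free to read individual attributes (e.g. peak_vram when CUDA is present) or forward the whole object to an event bus.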