Mirror of https://github.com/invoke-ai/InvokeAI.git

Compare commits: controlnet...psyche/fea (1 commit)

Commit 317cf01db4
@@ -2,9 +2,12 @@
 
 from __future__ import annotations
 
+import contextlib
 import inspect
+import os
 import re
+import sys
+import time
 import warnings
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -24,7 +27,9 @@ from typing import (
     Union,
 )
 
+import psutil
 import semver
+import torch
 from pydantic import BaseModel, ConfigDict, Field, JsonValue, TypeAdapter, create_model
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefined
@@ -140,6 +145,86 @@ class BaseInvocationOutput(BaseModel):
     )
 
 
+try:
+    import resource  # Unix-only; provides getrusage() for peak-RSS tracking
+except ImportError:
+    resource = None
+
+
+class MemoryTracker(contextlib.AbstractContextManager["MemoryTracker"]):
+    """
+    Context manager that records start/end RSS and peak RSS (window-specific, via baseline + delta) and, if CUDA
+    is available, start/end/peak VRAM (via torch.cuda).
+    """
+
+    def __init__(self):
+        self.torch_device = torch.device("cuda") if torch.cuda.is_available() else None
+
+    def __enter__(self):
+        self._proc = psutil.Process(os.getpid())
+        self.start_time = time.time()
+        self.start_rss: int = self._proc.memory_info().rss
+
+        if sys.platform == "win32":
+            self._base_peak = self._proc.memory_info().peak_wset
+        elif resource:
+            ru = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+            # On Linux, ru_maxrss is reported in kilobytes. On macOS, it is reported in bytes. We want bytes.
+            self._base_peak = ru if sys.platform == "darwin" else ru * 1024
+        else:
+            self._base_peak = self.start_rss
+
+        if self.torch_device:
+            torch.cuda.reset_peak_memory_stats(self.torch_device)
+            self.start_vram = torch.cuda.memory_allocated(self.torch_device)
+
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # end RSS
+        self.end_time = time.time()
+        self.end_rss: int = self._proc.memory_info().rss
+
+        # window-peak RSS = max(0, end_peak - baseline_peak)
+        if sys.platform == "win32":
+            end_peak = self._proc.memory_info().peak_wset
+        elif resource:
+            ru = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+            end_peak = ru if sys.platform == "darwin" else ru * 1024
+        else:
+            end_peak = self.end_rss
+
+        self.peak_rss = max(0, end_peak - self._base_peak)
+
+        # end & peak VRAM via torch
+        if self.torch_device:
+            self.end_vram = torch.cuda.memory_allocated(self.torch_device)
+            peak_alloc = torch.cuda.max_memory_allocated(self.torch_device)
+            self.peak_vram = max(0, peak_alloc - self.start_vram)
+
+        return False  # don't suppress exceptions
+
+    @property
+    def stats(self):
+        out = {
+            "start_time": self.start_time,
+            "end_time": self.end_time,
+            "elapsed": self.end_time - self.start_time,
+            "start_rss": self.start_rss,
+            "end_rss": self.end_rss,
+            "peak_rss": self.peak_rss,
+        }
+        if self.torch_device:
+            out.update(
+                {
+                    "start_vram": self.start_vram,
+                    "end_vram": self.end_vram,
+                    "peak_vram": self.peak_vram,
+                }
+            )
+        return out
+
+
 class RequiredConnectionException(Exception):
     """Raised when a field which requires a connection did not receive a value."""
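For review purposes, here is a minimal standalone sketch of how the new context manager can be exercised outside the invocation engine. It assumes the MemoryTracker class above is importable; the list allocation is a throwaway stand-in workload, not part of the commit:

    # Standalone sketch (not part of the commit): exercise MemoryTracker directly.
    with MemoryTracker() as tracker:
        buf = [0] * 50_000_000  # ~400 MB of pointer storage on 64-bit CPython
        del buf

    stats = tracker.stats
    print(f"elapsed:  {stats['elapsed']:.3f}s")
    print(f"peak RSS: {stats['peak_rss'] / 1024**2:.1f} MiB over the pre-window peak")
    if "peak_vram" in stats:
        print(f"peak VRAM: {stats['peak_vram'] / 1024**2:.1f} MiB")

One caveat worth flagging in review: ru_maxrss and peak_wset are process-lifetime high-water marks, so peak_rss is a delta against the baseline peak and will read 0 for any window whose allocations never exceed an earlier lifetime peak.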
@@ -192,49 +277,55 @@ class BaseInvocation(ABC, BaseModel):
         """Invoke with provided context and return outputs."""
         pass
 
-    def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput:
+    def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> tuple[BaseInvocationOutput, MemoryTracker]:
         """
         Internal invoke method, calls `invoke()` after some prep.
         Handles optional fields that are required to call `invoke()` and invocation cache.
         """
-        for field_name, field in self.model_fields.items():
-            if not field.json_schema_extra or callable(field.json_schema_extra):
-                # something has gone terribly awry, we should always have this and it should be a dict
-                continue
-
-            # Here we handle the case where the field is optional in the pydantic class, but required
-            # in the `invoke()` method.
-
-            orig_default = field.json_schema_extra.get("orig_default", PydanticUndefined)
-            orig_required = field.json_schema_extra.get("orig_required", True)
-            input_ = field.json_schema_extra.get("input", None)
-            if orig_default is not PydanticUndefined and not hasattr(self, field_name):
-                setattr(self, field_name, orig_default)
-            if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
-                if input_ == Input.Connection:
-                    raise RequiredConnectionException(self.model_fields["type"].default, field_name)
-                elif input_ == Input.Any:
-                    raise MissingInputException(self.model_fields["type"].default, field_name)
-
-        # skip node cache codepath if it's disabled
-        if services.configuration.node_cache_size == 0:
-            return self.invoke(context)
-
         output: BaseInvocationOutput
-        if self.use_cache:
-            key = services.invocation_cache.create_key(self)
-            cached_value = services.invocation_cache.get(key)
-            if cached_value is None:
-                services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
-                output = self.invoke(context)
-                services.invocation_cache.save(key, output)
-                return output
-            else:
-                services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
-                return cached_value
-        else:
-            services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
-            return self.invoke(context)
+        with MemoryTracker() as memory_tracker:
+            for field_name, field in self.model_fields.items():
+                if not field.json_schema_extra or callable(field.json_schema_extra):
+                    # something has gone terribly awry, we should always have this and it should be a dict
+                    continue
+
+                # Here we handle the case where the field is optional in the pydantic class, but required
+                # in the `invoke()` method.
+
+                orig_default = field.json_schema_extra.get("orig_default", PydanticUndefined)
+                orig_required = field.json_schema_extra.get("orig_required", True)
+                input_ = field.json_schema_extra.get("input", None)
+                if orig_default is not PydanticUndefined and not hasattr(self, field_name):
+                    setattr(self, field_name, orig_default)
+                if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
+                    if input_ == Input.Connection:
+                        raise RequiredConnectionException(self.model_fields["type"].default, field_name)
+                    elif input_ == Input.Any:
+                        raise MissingInputException(self.model_fields["type"].default, field_name)
+
+            # skip node cache codepath if it's disabled
+            if services.configuration.node_cache_size == 0:
+                output = self.invoke(context)
+            elif self.use_cache:
+                key = services.invocation_cache.create_key(self)
+                cached_value = services.invocation_cache.get(key)
+                if cached_value is None:
+                    services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
+                    output = self.invoke(context)
+                    services.invocation_cache.save(key, output)
+                else:
+                    services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
+                    output = cached_value
+            else:
+                services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
+                output = self.invoke(context)
+
+        services.logger.info(memory_tracker.stats)
+
+        return (output, memory_tracker)
 
     id: str = Field(
         default_factory=uuid_string,
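The wrap-and-return-a-pair shape that invoke_internal now has could also be expressed once as a reusable helper rather than inlined. A hypothetical sketch of that pattern follows; track_memory is not part of this commit, and it assumes the MemoryTracker class from the hunk above plus Python 3.10+ for ParamSpec:

    from typing import Callable, ParamSpec, TypeVar

    P = ParamSpec("P")  # requires Python 3.10+
    R = TypeVar("R")

    def track_memory(fn: Callable[P, R]) -> Callable[P, "tuple[R, MemoryTracker]"]:
        """Hypothetical helper: run fn inside a MemoryTracker, return (result, tracker)."""
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> "tuple[R, MemoryTracker]":
            with MemoryTracker() as tracker:
                result = fn(*args, **kwargs)
            return (result, tracker)
        return wrapper

Returning the tracker instead of logging inside the context manager keeps the class free of any logger dependency and lets each caller decide what to do with the numbers, which is what the session-runner change below relies on.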
@@ -126,7 +126,7 @@ class DefaultSessionRunner(SessionRunnerBase):
         )
 
         # Invoke the node
-        output = invocation.invoke_internal(context=context, services=self._services)
+        (output, memory_tracker) = invocation.invoke_internal(context=context, services=self._services)
 
         # Save output and history
         queue_item.session.complete(invocation.id, output)
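On the runner side, the returned tracker can feed whatever per-node reporting the caller wants. A hedged sketch, assuming only the stats dict shape defined above (the helper name and message format are illustrative, not from this commit):

    def log_node_memory(logger, node_id: str, stats: dict) -> None:
        # Illustrative helper: summarize one node's MemoryTracker.stats in MiB.
        mib = 1024**2
        msg = (
            f"node {node_id}: {stats['elapsed']:.2f}s, "
            f"RSS {stats['start_rss'] / mib:.0f} -> {stats['end_rss'] / mib:.0f} MiB "
            f"(window peak +{stats['peak_rss'] / mib:.0f} MiB)"
        )
        if "peak_vram" in stats:
            msg += f", VRAM window peak +{stats['peak_vram'] / mib:.0f} MiB"
        logger.info(msg)

Formatting to MiB this way is a readability suggestion over the commit's bare services.logger.info(memory_tracker.stats), which logs raw byte counts.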