mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Naive TorchAutocastContext.
This commit is contained in:
0
invokeai/backend/model_cache_v2/__init__.py
Normal file
0
invokeai/backend/model_cache_v2/__init__.py
Normal file
18
invokeai/backend/model_cache_v2/torch_autocast_context.py
Normal file
18
invokeai/backend/model_cache_v2/torch_autocast_context.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
|
||||
def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
    """Move all tensor arguments to ``to_device``, then invoke ``func``.

    Non-tensor arguments are passed through unchanged.

    Args:
        func: The callable (typically a torch op) to invoke.
        args: Positional arguments; any ``torch.Tensor`` is moved to ``to_device``.
        kwargs: Keyword arguments, or ``None``. ``TorchDispatchMode.__torch_dispatch__``
            passes ``kwargs=None`` for ops with no keyword arguments, so ``None`` is
            treated the same as an empty dict.
        to_device: Target device for all tensor arguments.

    Returns:
        The result of ``func(*args, **kwargs)`` with tensor inputs on ``to_device``.
    """
    args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
    # Guard against kwargs=None (the torch dispatch convention) instead of crashing on .items().
    kwargs_on_device = (
        {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        if kwargs
        else {}
    )
    return func(*args_on_device, **kwargs_on_device)
|
||||
|
||||
|
||||
class TorchAutocastContext(TorchDispatchMode):
    """Naive dispatch mode that moves every tensor argument of every dispatched
    op to ``to_device`` before running it.

    While the context is active, each torch operation is intercepted via
    ``__torch_dispatch__`` and its tensor inputs are cast to the target device
    through ``cast_to_device_and_run``.
    """

    def __init__(self, to_device: torch.device):
        # TorchDispatchMode has its own initialization state; it must run.
        super().__init__()
        self._to_device = to_device

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # NOTE(review): debug logging — fires on every dispatched op, extremely verbose.
        print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
        print(f"Dispatch Log: {types}")
        # torch passes kwargs=None when the op has no keyword args; normalize to {}.
        return cast_to_device_and_run(func, args, kwargs or {}, self._to_device)
|
||||
24
tests/backend/model_cache_v2/test_torch_autocast_context.py
Normal file
24
tests/backend/model_cache_v2/test_torch_autocast_context.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_cache_v2.torch_autocast_context import TorchAutocastContext
|
||||
|
||||
|
||||
class DummyModule(torch.nn.Module):
    """Minimal two-layer linear network used as a test fixture."""

    def __init__(self):
        super().__init__()
        # Two stacked 10 -> 10 affine layers.
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through ``linear1`` then ``linear2``."""
        return self.linear2(self.linear1(x))
|
||||
|
||||
|
||||
def test_torch_autocast_context():
    """Smoke test: a CPU-initialized model processes a CUDA input under the context."""
    model = DummyModule()

    # Inside the context every dispatched op has its tensor args moved to CUDA,
    # so the CPU-resident weights can be combined with the CUDA input.
    with TorchAutocastContext(to_device=torch.device("cuda")):
        cuda_input = torch.randn(10, 10, device="cuda")
        output = model(cuda_input)
        print(output.shape)
|
||||
Reference in New Issue
Block a user