Naive TorchAutocastContext.

This commit is contained in:
Ryan Dick
2024-12-03 14:55:43 +00:00
parent 401fb392b8
commit 030832f30b
3 changed files with 42 additions and 0 deletions

View File

@@ -0,0 +1,18 @@
import torch
from torch.utils._python_dispatch import TorchDispatchMode
def cast_to_device_and_run(func, args, kwargs, to_device: torch.device):
args_on_device = [a.to(to_device) if isinstance(a, torch.Tensor) else a for a in args]
kwargs_on_device = {k: v.to(to_device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
return func(*args_on_device, **kwargs_on_device)
class TorchAutocastContext(TorchDispatchMode):
def __init__(self, to_device: torch.device):
self._to_device = to_device
def __torch_dispatch__(self, func, types, args, kwargs):
print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
print(f"Dispatch Log: {types}")
return cast_to_device_and_run(func, args, kwargs, self._to_device)

View File

@@ -0,0 +1,24 @@
import torch
from invokeai.backend.model_cache_v2.torch_autocast_context import TorchAutocastContext
class DummyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 10)
self.linear2 = torch.nn.Linear(10, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = self.linear2(x)
return x
def test_torch_autocast_context():
model = DummyModule()
with TorchAutocastContext(to_device=torch.device("cuda")):
x = torch.randn(10, 10, device="cuda")
y = model(x)
print(y.shape)