refactor(backend): use torchvision transforms for Kontext image preprocessing

Replace numpy-based normalization with torchvision transforms for
consistency with other image processing in the codebase
This commit is contained in:
psychedelicious
2025-08-04 21:08:11 +10:00
parent faf662d12e
commit df77a12efe

View File

@@ -1,6 +1,5 @@
import einops
import numpy as np
import torch
import torchvision.transforms as T
from einops import repeat
from PIL import Image
@@ -136,10 +135,17 @@ class KontextExtension:
# Use BICUBIC for smoother resizing to reduce artifacts
image = image.resize((final_width, final_height), Image.Resampling.BICUBIC)
# Convert to tensor with same normalization as BFL
image_np = np.array(image)
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
# Convert to tensor using torchvision transforms for consistency
# This matches the normalization used in image_resized_to_grid_as_tensor
transformation = T.Compose(
[
T.ToTensor(), # Converts PIL image to tensor and scales to [0, 1]
]
)
image_tensor = transformation(image)
# Convert from [0, 1] to [-1, 1] range expected by VAE
image_tensor = image_tensor * 2.0 - 1.0
image_tensor = image_tensor.unsqueeze(0) # Add batch dimension
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding