Mirror of https://github.com/invoke-ai/InvokeAI.git, synced 2026-02-02 16:05:13 -05:00.
10 lines
404 B
Python
10 lines
404 B
Python
import torch
|
|
|
|
|
|
def pad_with_zeros(orig_weight: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    """Pad a weight tensor with zeros to match the target shape.

    The values of ``orig_weight`` are copied into the leading corner of a new
    zero-filled tensor, i.e. ``result[:d0, :d1, ...] == orig_weight`` where
    ``(d0, d1, ...)`` is ``orig_weight.shape``. Every other element is zero.
    The result is a freshly allocated tensor with the same dtype and device as
    ``orig_weight``; it never aliases the input.

    Args:
        orig_weight: The tensor to pad.
        target_shape: The desired output shape. It must have the same number of
            dimensions as ``orig_weight`` and be at least as large along every
            dimension.

    Returns:
        A new tensor of shape ``target_shape``.

    Raises:
        ValueError: If ``target_shape`` has a different number of dimensions than
            ``orig_weight``, or is smaller than it along any dimension.
    """
    orig_shape = tuple(orig_weight.shape)
    target = tuple(target_shape)

    # Without this check, a lower-rank input would only be sliced over its leading
    # dims and then silently broadcast across the remaining target dims.
    if len(orig_shape) != len(target):
        raise ValueError(
            f"Cannot pad tensor of shape {orig_shape} to shape {target}: "
            "the number of dimensions differs."
        )
    # Padding can only grow a tensor. A smaller target would otherwise surface as
    # an opaque broadcasting error during the slice assignment below.
    if any(orig_dim > target_dim for orig_dim, target_dim in zip(orig_shape, target)):
        raise ValueError(
            f"Cannot pad tensor of shape {orig_shape} to shape {target}: "
            "the target is smaller along at least one dimension."
        )

    expanded_weight = torch.zeros(target_shape, dtype=orig_weight.dtype, device=orig_weight.device)
    # Select the leading corner that matches the original shape and copy the values in.
    slices = tuple(slice(0, dim) for dim in orig_shape)
    expanded_weight[slices] = orig_weight
    return expanded_weight
|