diff --git a/extra/models/clip.py b/extra/models/clip.py index dee7cd6401..d8c6879456 100644 --- a/extra/models/clip.py +++ b/extra/models/clip.py @@ -451,9 +451,9 @@ class OpenClipEncoder: def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor: image_features: Tensor = self.visual(image) - image_features /= image_features.square().sum([-1,-2], keepdim=True).sqrt() # Frobenius Norm + image_features /= image_features.square().sum(-1, keepdim=True).sqrt() # L2 norm over the last axis - text_features = self.encode_tokens(tokens).squeeze(0) - text_features /= text_features.square().sum([-1,-2], keepdim=True).sqrt() # Frobenius Norm + text_features = self.encode_tokens(tokens) + text_features /= text_features.square().sum(-1, keepdim=True).sqrt() # L2 norm over the last axis - return image_features @ text_features.T + return (image_features * text_features).sum(axis=-1)