fixed batched clip computation (#6292)

This commit is contained in:
Tobias Fischer
2024-08-26 20:48:15 -04:00
committed by GitHub
parent 3918f6eea0
commit 211bfb6d8a

View File

@@ -451,9 +451,9 @@ class OpenClipEncoder:
def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor:
  """Return the similarity score between paired image and text embeddings.

  The image is embedded with ``self.visual`` and the tokens with
  ``self.encode_tokens``. Each embedding is scaled to unit L2 length along
  its last axis, and the elementwise product of the two is summed over that
  last axis (i.e. a dot product of the normalized vectors).

  Args:
    tokens: passed unchanged to ``self.encode_tokens``.
    image: passed unchanged to ``self.visual``.

  Returns:
    A tensor with the last (feature) axis reduced away.
    NOTE(review): presumably one score per (image, text) pair when both
    inputs share a leading batch axis -- confirm against callers. The two
    feature tensors must have broadcast-compatible shapes.
  """
  image_features: Tensor = self.visual(image)
  # Normalize every embedding independently (per-row L2 norm). Reducing over
  # the last axis only keeps batch entries from influencing each other.
  image_features /= image_features.square().sum(-1, keepdim=True).sqrt()

  text_features = self.encode_tokens(tokens)
  text_features /= text_features.square().sum(-1, keepdim=True).sqrt()

  # Pairwise dot product along the feature axis: entry i of the image batch is
  # scored against entry i of the text batch, not against every text.
  return (image_features * text_features).sum(axis=-1)