mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fixed batched clip computation (#6292)
This commit is contained in:
@@ -451,9 +451,9 @@ class OpenClipEncoder:
|
||||
|
||||
def get_clip_score(self, tokens:Tensor, image:Tensor) -> Tensor:
|
||||
image_features: Tensor = self.visual(image)
|
||||
image_features /= image_features.square().sum([-1,-2], keepdim=True).sqrt() # Frobenius Norm
|
||||
image_features /= image_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
|
||||
|
||||
text_features = self.encode_tokens(tokens).squeeze(0)
|
||||
text_features /= text_features.square().sum([-1,-2], keepdim=True).sqrt() # Frobenius Norm
|
||||
text_features = self.encode_tokens(tokens)
|
||||
text_features /= text_features.square().sum(-1, keepdim=True).sqrt() # Frobenius Norm
|
||||
|
||||
return image_features @ text_features.T
|
||||
return (image_features * text_features).sum(axis=-1)
|
||||
|
||||
Reference in New Issue
Block a user