Update model.py

This commit is contained in:
Jong Wook Kim
2024-09-30 10:23:39 -07:00
committed by GitHub
parent 65a353771a
commit 3211024b53

View File

@@ -2,7 +2,7 @@ import base64
import gzip
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
from typing import Dict, Iterable, Optional, Tuple
import numpy as np
import torch
@@ -113,7 +113,7 @@ class MultiHeadAttention(nn.Module):
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
n_batch, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)