change Tensor.stack to method (#4719)

This commit is contained in:
chenyu
2024-05-24 17:04:19 -04:00
committed by GitHub
parent ba116ff630
commit 31358cbea5
15 changed files with 40 additions and 42 deletions

View File

@@ -85,7 +85,7 @@ def selective_scan_ref(
if i == u.shape[2] - 1:
last_state = x
ys.append(y)
y = Tensor.stack(ys, dim=2) # (batch dim L)
y = Tensor.stack(*ys, dim=2) # (batch dim L)
out = y if D is None else y + u * D.reshape((-1, 1))
if z is not None:
out = out * z.silu()