minor change to gelu (#5048)

used `math.sqrt(2 / math.pi)` instead of `0.7978845608`, and moved one mul by self inside the parentheses. this matches the paper and llm.c
chenyu
2024-06-18 22:26:56 -04:00
committed by GitHub
parent 4c7e316ded
commit 996788358d

@@ -2232,7 +2232,7 @@ class Tensor:
     print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
     ```
     """
-    return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
+    return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
   def quick_gelu(self):
     """