mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
revert layernorm to have axis param
This commit is contained in:
@@ -312,8 +312,7 @@ class Tensor:
|
||||
|
||||
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
|
||||
|
||||
def layernorm(self, eps=1e-5):
|
||||
axis = range(1, len(self.shape))
|
||||
def layernorm(self, axis=-1, eps=1e-5):
|
||||
y = (self - self.mean(axis=axis, keepdim=True))
|
||||
return y.div((y*y).mean(axis=axis, keepdim=True).add(eps).sqrt())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user