mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
clean the long lines in avg_pool2d and max_pool2d (#5091)
This commit is contained in:
@@ -1623,8 +1623,9 @@ class Tensor:
|
||||
print(t.avg_pool2d().numpy())
|
||||
```
|
||||
"""
|
||||
return self._pool(
|
||||
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
|
||||
kernel_size = make_pair(kernel_size)
|
||||
return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(-len(kernel_size), 0)))
|
||||
|
||||
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
|
||||
"""
|
||||
Applies max pooling over a tensor.
|
||||
@@ -1638,8 +1639,8 @@ class Tensor:
|
||||
print(t.max_pool2d().numpy())
|
||||
```
|
||||
"""
|
||||
return self._pool(
|
||||
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
|
||||
kernel_size = make_pair(kernel_size)
|
||||
return self._pool(kernel_size, stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(-len(kernel_size), 0)))
|
||||
|
||||
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user