Match Torch speed for sum reduction on M1 (#1187)

* Add an additional kernel when reducing multiple dimensions at once (sketched below)

* Faster for smaller inputs

* Whitespace and naming

* Cleaner, guard for Metal only, and max 1 split rather than N

* Draft of different approach

* One additional kernel call for this test (as expected)
Alexander Edwards authored 2023-07-20 02:18:58 +10:00, committed by GitHub
parent fde9f0e60d, commit 59af9b81c5
2 changed files with 15 additions and 2 deletions
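The core idea of the change, as a minimal NumPy sketch (illustrative only, not tinygrad code; the shapes and the 256 divisor are assumptions): a large sum is done as a partial reduce over one split axis (first kernel) followed by a much smaller finishing reduce (second kernel), instead of a single huge reduction.

# Minimal NumPy sketch of the two-kernel reduction (not the PR's code)
import numpy as np

x = np.random.rand(4096, 4096).astype(np.float32)

# single-step reduce: everything in one go
direct = x.sum()

# two-step reduce: split one reduced axis by a divisor, partially reduce it
# (first kernel), then reduce the remaining elements (second kernel)
divisor = 256
partial = x.reshape(4096 // divisor, divisor, 4096).sum(axis=1)  # first kernel
split = partial.sum()                                            # second kernel

assert np.isclose(direct, split, rtol=1e-4)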

@@ -63,7 +63,7 @@ class TestInferenceMinKernels(unittest.TestCase):
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
     img = Tensor.randn(1, 3, 224, 224)
     # TODO: this seems very high
-    with CLCache(115):
+    with CLCache(116):
       model.forward(img).realize()

   def test_resnet(self):

@@ -1,4 +1,5 @@
 from __future__ import annotations
+import math
 import operator
 from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast
 import sys, importlib, inspect, functools, pathlib
@@ -208,11 +209,23 @@ class LazyBuffer:
       return root.reshape(ret.st.shape)
     return ret

-  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
+  def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
     if self.shape == tuple(new_shape): return self
     srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
     return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype)

+  def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
+    if prod(self.shape) // prod(new_shape) > 8192:
+      reduced_dimensions = [(i, math.gcd(256, old), stride) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new]
+      dimension_to_split, divisor, _ = max(reduced_dimensions, key=lambda v: v[1]//(v[2] or math.inf))  # heuristic -> choose largest divisor to split on, penalize large strides
+      intermediate_input_shape = self.shape[:dimension_to_split] + (self.shape[dimension_to_split]//divisor, divisor) + self.shape[dimension_to_split+1:]
+      intermediate_output_shape = self.shape[:dimension_to_split] + (self.shape[dimension_to_split]//divisor, 1) + self.shape[dimension_to_split+1:]
+      final_input_shape = self.shape[:dimension_to_split] + (self.shape[dimension_to_split]//divisor,) + self.shape[dimension_to_split+1:]
+      return self.reshape(intermediate_input_shape)._reduce_op(op, intermediate_output_shape).reshape(final_input_shape)._reduce_op(op, new_shape)
+    return self._reduce_op(op, new_shape)
+
   def reshape(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if self.shape == arg: return self
     if not self.realized and self.op.op == MovementOps.RESHAPE:
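For a concrete feel of the new reduce_op path above, a rough standalone illustration (plain Python; the (2048, 2048) shape, the contiguous strides, and the variable names are assumptions, and ShapeTracker/real_strides() details are ignored) of how the heuristic picks the split axis and what shapes the two reduces see:

# Standalone illustration of the split heuristic (illustrative, not tinygrad code)
import math

shape, new_shape, strides = (2048, 2048), (1, 1), (2048, 1)

# reduction factor 2048*2048 // 1 = 4_194_304 > 8192, so the split path is taken
reduced = [(i, math.gcd(256, old), st) for i, (old, new, st) in enumerate(zip(shape, new_shape, strides)) if old != new]
dim, divisor, _ = max(reduced, key=lambda v: v[1] // (v[2] or math.inf))  # prefers the small-stride axis

intermediate_input  = shape[:dim] + (shape[dim] // divisor, divisor) + shape[dim+1:]
intermediate_output = shape[:dim] + (shape[dim] // divisor, 1) + shape[dim+1:]
final_input         = shape[:dim] + (shape[dim] // divisor,) + shape[dim+1:]

print(dim, divisor)          # 1 256 -> split the stride-1 axis by 256
print(intermediate_input)    # (2048, 8, 256): first reduce collapses the 256-wide axis
print(intermediate_output)   # (2048, 8, 1)
print(final_input)           # (2048, 8): second reduce then collapses this to (1, 1)

The first _reduce_op collapses only the 256-wide split axis and the second collapses everything that remains, which is why the ViT test above records exactly one extra kernel call.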