mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
force reduce to be in axis order (#10837)
* force reduce to be in axis order
* disable rule causing loop
* disable that rule
* no ra there
* only move non reduce
* fix tests
@@ -554,7 +554,7 @@ class TestLinearizer(unittest.TestCase):
     idxs = Tensor([0,3,5,6]).realize()
     with Context(FUSE_ARANGE=1):
       sink = dataset[idxs].contiguous().kernelize().uop.base.src[1].arg.ast
-      real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1)
+      real_index = dataset.numpy()[idxs.numpy()].reshape(4, 256, 1, 1)
       helper_linearizer_ast(sink, [dataset, idxs], wanna_output=[real_index])

   # AssertionError: repeated stores in uops
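Why the expected shape changes: with reduce axes forced to come after kept axes, the gathered 256-wide dim now precedes the two reduced singleton dims instead of sitting between them. A minimal numpy sketch of the two layouts (the dataset shape here is an assumption for illustration, not taken from the test):

import numpy as np

# assumed stand-in for the test's `dataset`; only the 256-wide second dim matters here
dataset = np.random.rand(16, 256).astype(np.float32)
idxs = np.array([0, 3, 5, 6])

gathered = dataset[idxs]                      # shape (4, 256)
old_layout = gathered.reshape(4, 1, 256, 1)   # reduced dim sat between the kept dims
new_layout = gathered.reshape(4, 256, 1, 1)   # reduced dims now trail the kept dims
assert old_layout.size == new_layout.size     # same data, different axis order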
@@ -1622,7 +1622,7 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape

   def test_conv2d(self): _test_conv2d(7)
-  def test_conv2d_fused(self): _test_conv2d(6, FUSE_CONV_BW=1)
+  def test_conv2d_fused(self): _test_conv2d(5, FUSE_CONV_BW=1)

   @unittest.skipUnless(is_dtype_supported(dtypes.half) and is_dtype_supported(dtypes.ulong), "need half and ulong")
   def test_conv2d_half(self): _test_conv2d(7, dtype=dtypes.half)
@@ -241,7 +241,7 @@ view_right = merge_views+PatternMatcher([
   # apply view after reduceops
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="src"),), name="v"),), name="r"), reduceop_view_right),
   # apply view after elementwise ops
-  (UPat(GroupOp.All-{Ops.SINK, Ops.GBARRIER}, name="root"), elementwise_view_right),
+  (UPat(GroupOp.All-{Ops.SINK, Ops.GBARRIER, Ops.REDUCE_AXIS}, name="root"), elementwise_view_right),
   # merge axes for double reduce (invert of SPLIT_REDUCEOP=1)
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="r1"),), name="r2"),
    lambda r1,r2: r1.replace(arg=(r1.arg[0], r2.arg[1]+r1.arg[1])) if r1.arg[0] is r2.arg[0] else None),
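Per the commit message ("disable rule causing loop"), REDUCE_AXIS is now excluded from the catch-all elementwise rule, so the view-pushing rewrite no longer fires on reduces, presumably because it could undo the axis ordering enforced in UOp.r and loop. A toy sketch of the set-difference exclusion (hypothetical names, not tinygrad's real PatternMatcher API):

# a UPat built from GroupOp.All minus a set matches every op outside that set
ALL_OPS = {"ADD", "MUL", "VIEW", "REDUCE_AXIS", "SINK", "GBARRIER"}
EXCLUDED = {"SINK", "GBARRIER", "REDUCE_AXIS"}   # REDUCE_AXIS newly excluded

def elementwise_view_right_applies(op: str) -> bool:
  return op in ALL_OPS - EXCLUDED

assert elementwise_view_right_applies("MUL")              # elementwise ops still match
assert not elementwise_view_right_applies("REDUCE_AXIS")  # reduces are no longer rewritten here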
@@ -253,7 +253,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
   def r(self, op:Ops, axis:tuple[int, ...]):
     axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
-    return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
+    if len(axis) == 0: return self
+    # move any non reduce axis before the first reduce axis
+    move_early = [i for i in range(axis[0], len(self.shape)) if i not in axis and resolve(self.shape[i] != 1)]
+    if move_early:
+      permute = tuple(range(axis[0])) + tuple(move_early) + tuple([i for i in range(axis[0], len(self.shape)) if i not in move_early])
+      ret = self.permute(permute)
+      new_axis = tuple([x for x in range(axis[0]+len(move_early), len(self.shape)) if resolve(ret.shape[x] != 1)])
+      assert len(axis) == len(new_axis)
+    else:
+      ret, new_axis = self, axis
+    ret = UOp(Ops.REDUCE_AXIS, self.dtype, (ret,), (op, new_axis))
+    return ret.reshape(tuple([x if i not in axis else 1 for i,x in enumerate(self.shape)]))
   def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
   def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
   def contiguous(self): return self.alu(Ops.CONTIGUOUS)
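This is the heart of the change: UOp.r now permutes any kept (non-reduce, non-size-1) axis in front of the first reduce axis before emitting REDUCE_AXIS, so the reduced axes always come last, then reshapes back to the original shape with 1s in the reduced positions. A standalone sketch of just the permutation bookkeeping, using plain tuples with no tinygrad dependency (the function name is hypothetical):

def reduce_permutation(shape: tuple[int, ...], axis: tuple[int, ...]):
  # mirrors the new UOp.r: kept (non-reduce, non-1) axes at or after the first
  # reduce axis are moved in front of it, so the reduce axes end up last
  axis = tuple(sorted(x for x in axis if shape[x] != 1))
  if not axis: return tuple(range(len(shape))), axis   # UOp.r returns self here
  move_early = [i for i in range(axis[0], len(shape)) if i not in axis and shape[i] != 1]
  if not move_early: return tuple(range(len(shape))), axis
  permute = (tuple(range(axis[0])) + tuple(move_early)
             + tuple(i for i in range(axis[0], len(shape)) if i not in move_early))
  permuted_shape = tuple(shape[i] for i in permute)
  new_axis = tuple(x for x in range(axis[0]+len(move_early), len(shape)) if permuted_shape[x] != 1)
  assert len(axis) == len(new_axis)
  return permute, new_axis

# reducing axis 1 of a (4, 64, 256) tensor: the kept 256-dim is permuted ahead
# of the 64-dim being reduced, so the reduce axis lands last
perm, new_axis = reduce_permutation((4, 64, 256), (1,))
assert perm == (0, 2, 1) and new_axis == (2,)
assert reduce_permutation((4, 64, 256), (2,))[0] == (0, 1, 2)   # already in order: identity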