add GROUP and GROUPTOP to test_arange (#9432)

it does not grow quadratically, but it's not 0 ops now
This commit is contained in:
chenyu
2025-03-13 11:28:38 -04:00
committed by GitHub
parent 90ffa9bd45
commit 99b0287e4e

View File

@@ -25,23 +25,29 @@ class TestArange(unittest.TestCase):
return p.estimates.ops
def test_complexity(self, opts=None, limit=None):
# add 1 to avoid divide by 0. arange is 0 flops now!
f1 = self._get_flops(256, opts) + 1
f2 = self._get_flops(2560, opts) + 1
f1 = self._get_flops(256, opts)
f2 = self._get_flops(2560, opts)
print(f"{f1=}, {f2=}")
assert (f1 < 6000 and f2 < 6000) or (f2 / f1 < 16), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
# add 1 to avoid divide by 0. arange is 0 flops now!
assert (f1 < 6000 and f2 < 6000) or ((f2+1) / (f1+1) < 16), f"bad complexity, flops {(f2+1) / (f1+1):.1f}X while inputs 10X"
if limit is not None and not getenv("PTX"):
# PTX counts index ALU in flops
assert f1 <= limit, f"{f1=}, {limit=}"
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=1)
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=1)
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=1)
def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=1)
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1)
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=0)
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=0)
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=0)
def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=0)
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=0)
@unittest.skip("doesn't work yet")
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, arg=32)])
if Device.default.renderer.has_local:
# TODO: fix limit
def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920)
def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)
def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)
@unittest.skip("doesn't work yet")
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])
def test_all_opts(self, opts=None, exclude=None):
k = Kernel(Tensor.arange(256).schedule()[-1].ast)