use UOp.st for kernel reduce axes (#8499)

* use UOp.st for kernel reduce axes [pr]

* do not return dict
This commit is contained in:
qazal
2025-01-13 06:24:11 -05:00
committed by GitHub
parent 7562cc0399
commit 586e730d32
3 changed files with 11 additions and 12 deletions

View File

@@ -2157,9 +2157,9 @@ class TestKernelOpts(unittest.TestCase):
data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
helper_linearizer_ast(sink, [data1, data2], opts=[
-[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)],
-[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)],
-[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)]
+#[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)],
+#[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)],
+#[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)]
])
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")

View File

@@ -75,10 +75,10 @@ class TestVerifyAST(unittest.TestCase):
def test_buffer_uops_st(self):
a = Tensor.randn(4, 4)+2
-uop_sts = verify_ast(a.schedule()[-1].ast)
-store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0]
+verify_ast(ast:=a.schedule()[-1].ast)
+store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0]
 self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
-const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
+const_st = [u.st for u in ast.toposort if u.op is Ops.VALID][0]
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
def test_assert_swizzle(self):