mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
use UOp.st for kernel reduce axes (#8499)
* use UOp.st for kernel reduce axes [pr] * do not return dict
This commit is contained in:
@@ -2157,9 +2157,9 @@ class TestKernelOpts(unittest.TestCase):
|
||||
data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
|
||||
data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
|
||||
helper_linearizer_ast(sink, [data1, data2], opts=[
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)],
|
||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)]
|
||||
#[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.GROUP, 0, 4)],
|
||||
#[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8)],
|
||||
#[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8), Opt(OptOps.GROUP, 0, 4)]
|
||||
])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
|
||||
@@ -75,10 +75,10 @@ class TestVerifyAST(unittest.TestCase):
|
||||
|
||||
def test_buffer_uops_st(self):
|
||||
a = Tensor.randn(4, 4)+2
|
||||
uop_sts = verify_ast(a.schedule()[-1].ast)
|
||||
store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0]
|
||||
verify_ast(ast:=a.schedule()[-1].ast)
|
||||
store_st = [u.st for u in ast.toposort if u.op is Ops.STORE][0]
|
||||
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
|
||||
const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
|
||||
const_st = [u.st for u in ast.toposort if u.op is Ops.VALID][0]
|
||||
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
|
||||
|
||||
def test_assert_swizzle(self):
|
||||
|
||||
Reference in New Issue
Block a user