Move Ops.SPECIAL arg to src (#11918)

* initial moving bound to src

* arg to src

* remove import

* fixup linearizer

* arg to src

* fix test_uop_graph

* fix more tests

* fix python renderer

* get const value from const uop

* ssimplify uop estimates

* fix webgpu locals

* fix old test

* gate Ops.SPECIAL in linearizer

* use ssimplify() for local/global_size

* remove toposort gate_parents_instead_of_self

* fix rendering in comment

* cleanup

* rename and add comments

* add BottomUpGate with test
This commit is contained in:
Sieds Lykles
2025-09-04 09:31:44 +02:00
committed by GitHub
parent 5cf42dc4db
commit 572a3c15c6
20 changed files with 114 additions and 76 deletions

View File

@@ -46,7 +46,7 @@ class TestRendererFailures(unittest.TestCase):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
def test_gated_store_with_alu(self):
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, gate_alu), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
@@ -56,8 +56,8 @@ class TestRendererFailures(unittest.TestCase):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
def test_gated_store_with_alu_2d(self):
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx1', 2))).ne(0)
gate_alu_0 = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0+lidx1*4, gate_alu_0&gate_alu_1), UOp.const(dtypes.int, 1)))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
@@ -101,7 +101,7 @@ class TestPTXFailures(unittest.TestCase):
@unittest.skip("INDEX can only have a gate ALU parent, not an IF")
def test_gated_store_with_if(self):
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
val = UOp.const(dtypes.int, 1)
if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val))