update UOp.SPECIAL arg spec [run_process_replay] (#5661)

* update UOp.SPECIAL arg spec [run_process_replay]

Change the UOps.SPECIAL arg from `(0, "gidx0", 4)` to just `("gidx0", 4)`, dropping the leading index so the arg is closer to a Variable.

* fix ptx
This commit is contained in:
chenyu
2024-07-23 16:58:12 -04:00
committed by GitHub
parent 4d47968580
commit 16c27ae400
10 changed files with 28 additions and 28 deletions

View File

@@ -783,7 +783,7 @@ class TestLinearizer(unittest.TestCase):
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs]))
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
sizes = [x.arg[2] for x in loop_idxs]
sizes = [x.arg[1] for x in loop_idxs]
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
@@ -840,9 +840,9 @@ class TestLinearizer(unittest.TestCase):
k = helper_linearizer_opt(t+1)[0]
idxs = dedup([uop for uop in k.uops if uop.op is UOps.SPECIAL])
idxs = sorted(idxs, key=lambda uop: uop.arg[0])
assert idxs[0].arg == (0, 'gidx0', 6), idxs[0].arg
assert idxs[1].arg == (1, 'gidx1', 5), idxs[1].arg
assert idxs[2].arg == (2, 'gidx2', 4), idxs[2].arg
assert idxs[0].arg == ('gidx0', 6), idxs[0].arg
assert idxs[1].arg == ('gidx1', 5), idxs[1].arg
assert idxs[2].arg == ('gidx2', 4), idxs[2].arg
def test_div_collapse(self):
def helper(t, msg, max_ops=0):

View File

@@ -222,7 +222,7 @@ class TestUOpGraph(TestUOps):
def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1))
lidx = UOp(UOps.SPECIAL, dtypes.int, (), (0, "lidx1", 16))
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
st = UOp(UOps.STORE, None, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
barrier = UOp(UOps.BARRIER, None, (st, ))
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2), barrier))

View File

@@ -240,7 +240,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
@unittest.expectedFailure
def test_tiny_gate_store(self):
gmem = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
@@ -258,7 +258,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
def test_gate_some_stores(self):
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0*UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
@@ -277,7 +277,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
def test_merge_ifs_alt(self):
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0*UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))