mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update UOp.SPECIAL arg spec [run_process_replay] (#5661)
* update UOp.SPECIAL arg spec [run_process_replay]
from `(0, "gid0", 4)` to just `("gid0", 4)`. closer to a Variable
* fix ptx
This commit is contained in:
@@ -783,7 +783,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
|
||||
loop_idxs = dedup(flatten([[y for y in sorted(list(x.sparents)) if y.op is UOps.SPECIAL] for x in idxs]))
|
||||
loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
|
||||
sizes = [x.arg[2] for x in loop_idxs]
|
||||
sizes = [x.arg[1] for x in loop_idxs]
|
||||
assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
|
||||
assert len(loop_idxs) == min(len(sizes), len(dims)), f"expected idxs to have length {min(len(sizes), len(dims))}, got {len(loop_idxs)}"
|
||||
assert sizes == expected_sizes, f"expected sizes={expected_sizes}, got {sizes=}"
|
||||
@@ -840,9 +840,9 @@ class TestLinearizer(unittest.TestCase):
|
||||
k = helper_linearizer_opt(t+1)[0]
|
||||
idxs = dedup([uop for uop in k.uops if uop.op is UOps.SPECIAL])
|
||||
idxs = sorted(idxs, key=lambda uop: uop.arg[0])
|
||||
assert idxs[0].arg == (0, 'gidx0', 6), idxs[0].arg
|
||||
assert idxs[1].arg == (1, 'gidx1', 5), idxs[1].arg
|
||||
assert idxs[2].arg == (2, 'gidx2', 4), idxs[2].arg
|
||||
assert idxs[0].arg == ('gidx0', 6), idxs[0].arg
|
||||
assert idxs[1].arg == ('gidx1', 5), idxs[1].arg
|
||||
assert idxs[2].arg == ('gidx2', 4), idxs[2].arg
|
||||
|
||||
def test_div_collapse(self):
|
||||
def helper(t, msg, max_ops=0):
|
||||
|
||||
@@ -222,7 +222,7 @@ class TestUOpGraph(TestUOps):
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
|
||||
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1))
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), (0, "lidx1", 16))
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
st = UOp(UOps.STORE, None, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
||||
barrier = UOp(UOps.BARRIER, None, (st, ))
|
||||
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.bool, False), UOp.const(dtypes.int, 2), barrier))
|
||||
|
||||
@@ -240,7 +240,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_tiny_gate_store(self):
|
||||
gmem = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
@@ -258,7 +258,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
def test_gate_some_stores(self):
|
||||
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
|
||||
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
@@ -277,7 +277,7 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
def test_merge_ifs_alt(self):
|
||||
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
|
||||
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||
val = UOp.const(dtypes.float, 42.0)
|
||||
gate = gidx0.lt(UOp.const(dtypes.int, 1))
|
||||
|
||||
Reference in New Issue
Block a user