require renderer argument in get_program, removes device opening in process replay [pr] (#13524)

Author: qazal
Date: 2025-12-03 02:05:31 +08:00 (committed by GitHub)
Parent: 21184ae6b1
Commit: 366badaa68
10 changed files with 60 additions and 56 deletions
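Every hunk below applies the same mechanical change: callers of `get_program` now pass the target renderer explicitly instead of letting `get_program` resolve one itself, which previously required opening a device (a problem for process replay, which should not touch real devices). A minimal sketch of the new call pattern, under the assumption about `get_program`'s import path noted in the comment, since it is not shown in this diff:

```python
from tinygrad import Tensor, Device
from tinygrad.codegen import get_program  # assumed import path; not shown in this diff

# build a trivial kernel AST the same way the tests do
out = Tensor.ones(4).contiguous() + 1
ast = out.schedule()[-1].ast

# old (removed by this commit): get_program(ast, opts=[])
# new: the caller supplies the renderer; get_program itself no longer opens a device
program = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[])
print(len(program.uops))
```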

@@ -45,7 +45,7 @@ class TestLinearizer(unittest.TestCase):
tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
out = tst.neg().cast(dtypes.char).cast(dtypes.int).cast(dtypes.char) * 2
ast = helper_linearizer_opt(out)
-uops = get_program(ast, opts=[]).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
@unittest.expectedFailure
@@ -53,7 +53,7 @@ class TestLinearizer(unittest.TestCase):
tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
out = tst.neg().cast(dtypes.char).cast(dtypes.int) * 2
ast = helper_linearizer_opt(out)
-uops = get_program(ast, opts=[]).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx")
@@ -63,7 +63,7 @@ class TestLinearizer(unittest.TestCase):
b = Tensor.empty(16)
out = img.conv2d(w, b)
ast = helper_linearizer_opt(out)
-uops = get_program(ast, opts=[]).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
# slice at the last loop end
uslice = [i for i,u in enumerate(uops) if u.op == Ops.END][-1]
# only valid test if outermost range is the reduce
@@ -84,7 +84,7 @@ class TestLinearizer(unittest.TestCase):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).sum()
ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])
-uops = get_program(ast, opts=[]).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
@@ -92,7 +92,7 @@ class TestLinearizer(unittest.TestCase):
a = Tensor.randn(2, ).realize()
out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])
-uops = get_program(ast, opts=[]).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
@@ -101,7 +101,7 @@ class TestLinearizer(unittest.TestCase):
a = Tensor([2, 2]).realize()
out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum()
ast = helper_linearizer_opt(out, wanna_output=[24])
-uops = get_program(ast, opts=[]).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
# RANGE -> ALU -> RANGE -> ALU + LOAD -> STORE
assert any(x.op in GroupOp.ALU for x in uops[ranges[0]:ranges[1]])
@@ -114,7 +114,7 @@ class TestLinearizer(unittest.TestCase):
b = Tensor.randn(1, 1).realize()
out = (a + b[0]).sum() + b[0]
ast = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])
-uops = get_program(ast, opts=[]).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
# LOAD -> RANGE -> LOAD -> STORE
assert len([x for x in uops[:ranges[0]] if x.op is Ops.LOAD]) == 1
@@ -124,7 +124,7 @@ class TestLinearizer(unittest.TestCase):
b = Tensor.randn(1, 1).realize()
out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
ast = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])
-uops = get_program(ast, opts=[]).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
assert len(ranges) == 1 # NOTE: it collapses now
@@ -135,7 +135,7 @@ class TestLinearizer(unittest.TestCase):
# these are of size 3 to avoid float4 coalesce
r = a[:-1] + a[1:]
-uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
+uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
@@ -147,7 +147,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])
-uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
+uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
assert num_ops <= 1, "more alu uops than needed"
@@ -156,7 +156,8 @@ class TestLinearizer(unittest.TestCase):
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
r = Tensor.conv2d(x,w,padding=1).relu()
-uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
+uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer,
+  opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
accs = [u for u in uops if u.op is Ops.DEFINE_REG]
stores = [u for u in uops if u.op is Ops.STORE]
assert len(accs) == 0 # it's removed now
@@ -179,7 +180,7 @@ class TestLinearizer(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
opts_to_apply = [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
-program = get_program(r.schedule()[-1].ast, opts=opts_to_apply)
+program = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=opts_to_apply)
stores = [u for u in program.uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
@@ -194,7 +195,7 @@ class TestLinearizer(unittest.TestCase):
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack(a, b)
-uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
+uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
assert num_ops == 0, "more alu uops than needed"
@@ -204,14 +205,14 @@ class TestLinearizer(unittest.TestCase):
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
realized_ast = a.schedule()[-1].ast
-program = get_program(realized_ast, opts=[])
+program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=[])
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
assert local[0].dtype.base == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
realized_ast = c.schedule()[-1].ast
-program = get_program(realized_ast, opts=[])
+program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=[])
local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
self.assertEqual(local[0].dtype.base, expected_dtype)
@@ -239,7 +240,7 @@ class TestLinearizer(unittest.TestCase):
opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
ast = helper_linearizer_opt(r, [opt])
# the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
-uops = get_program(ast, opts=opt).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=opt).uops
begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype)
@@ -353,7 +354,7 @@ class TestLinearizer(unittest.TestCase):
# shrink so that the dims do not collapse
t = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6)))
ast = helper_linearizer_opt(t+1)
-uops = get_program(ast, opts=[]).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
idxs = dedup([uop for uop in uops if uop.op is Ops.SPECIAL])
idxs = sorted(idxs, key=lambda uop: uop.arg)
assert (idxs[0].arg, idxs[0].src[0].arg) == ('gidx0', 6), idxs[0]
@@ -386,13 +387,13 @@ class TestLinearizer(unittest.TestCase):
sched_copy = sched[:]
run_schedule(sched)
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
-program = get_program(sched_copy[-1].ast, opts=())
+program = get_program(sched_copy[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=())
assert not any(u.op == Ops.WHERE for u in program.uops), "found where where where should be folded"
def test_phi_simplification(self):
def helper(t, max_ops=0):
ast = helper_linearizer_opt(t)
-uops = get_program(ast).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops
# ignore kernel optimized IF statements for now
if if_op:=next((u for u in uops if u.op is Ops.IF), None):
uops = uops[:uops.index(if_op)]
@@ -425,7 +426,7 @@ class TestLinearizer(unittest.TestCase):
out = x.matmul(y)
with Context(TC=0):
ast = helper_linearizer_opt(out)
-uops = get_program(ast).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops
# check that the float4 cast collapses
store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
for val in store_vals:
@@ -436,7 +437,7 @@ class TestLinearizer(unittest.TestCase):
x = Tensor.randn((4,3,6,6)).realize()
out = x.flip((0,1)).contiguous()
ast = helper_linearizer_opt(out)
-store_val = [u.src[1] for u in get_program(ast).uops if u.op is Ops.STORE][0]
+store_val = [u.src[1] for u in get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops if u.op is Ops.STORE][0]
assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.VECTORIZE
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -449,7 +450,7 @@ class TestLinearizer(unittest.TestCase):
Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
ast = helper_linearizer_opt(out, opts=[opt])
def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
-uops = get_program(ast, opts=opt).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=opt).uops
local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_GLOBAL for x in get_recursive(u.src[0]))]
barrier = [u for u in uops if u.op is Ops.BARRIER]
@@ -470,7 +471,7 @@ class TestLinearizer(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
ast = helper_linearizer_opt(r)
-uops = get_program(ast).uops
+uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops
stores = [u for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
# the float4 value stores directly in lds and we skip upcast
@@ -517,7 +518,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
device = real_bufs[0].device
wanna_output = [np.array(x).flatten() for x in wanna_output]
-def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, opts=opts), device=device))
+def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts), device=device))
def check_opt(opts):
prg = get_prg(opts=opts)
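The last hunk shows the other half of the change: the test helper builds the program against the default device's renderer, then retargets it to the device that owns the real buffers before wrapping it in a `CompiledRunner`. Since the `replace(..., device=device)` call in the diff implies the returned program object is a dataclass with a `device` field, the retargeting step looks roughly like this sketch (import paths assumed, not shown in this diff):

```python
from dataclasses import replace
from tinygrad import Device
from tinygrad.codegen import get_program            # assumed import path
from tinygrad.engine.realize import CompiledRunner  # assumed import path

def get_prg(realized_ast, real_bufs, opts):
  # render against the default renderer, then retarget the resulting
  # program to the device that owns the test buffers and compile it
  program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts)
  return CompiledRunner(replace(program, device=real_bufs[0].device))
```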