require renderer argument in get_program, removes device opening in process replay [pr] (#13524)
@@ -14,7 +14,6 @@ try:
   from tinygrad.uop.ops import UOp, Ops, KernelInfo
   from tinygrad.codegen.opt import Opt
   from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm, BEAM
-  from tinygrad.device import Device
 except ImportError as e:
   print(repr(e))
   exit(int(ASSERT_DIFF))
@@ -52,12 +51,10 @@ def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str,
     return "\n".join([f"{len(asts)} kernels", *asts])
   return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)

-def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
+def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
   # the ast.arg is non None if we are inside of search.py
   sink_arg = ast.arg or KernelInfo(opts_to_apply=tuple(opts) if opts is not None else p.applied_opts if BEAM>=1 else None)
   input_ast = ast.replace(arg=replace(sink_arg, name=p.name))
-  # if no renderer was provided, open the device to get it
-  if renderer is None: renderer = Device[p.device].renderer
   p2 = get_program(input_ast, renderer=renderer)
   def to_str(ret:ProgramSpec) -> str:
     # PYTHON renderer pickles UOps, first unpickle and decode here

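The replay path above attaches opts to the sink through KernelInfo(opts_to_apply=...) and then calls get_program with an explicit renderer. A minimal sketch of that same mechanism outside process replay; the tensor and the opt chosen here are illustrative, not from this commit:

    from tinygrad import Tensor, Device
    from tinygrad.uop.ops import KernelInfo
    from tinygrad.codegen.opt import Opt, OptOps
    from tinygrad.engine.realize import get_program

    # build a scheduled ast, then ride the opts on the sink's KernelInfo arg,
    # the same mechanism replay_get_program uses above
    ast = (Tensor.ones(16, 16).contiguous() + 1).schedule()[-1].ast
    ast = ast.replace(arg=KernelInfo(opts_to_apply=(Opt(OptOps.UPCAST, 0, 4),)))
    prg = get_program(ast, renderer=Device[Device.DEFAULT].renderer)
    print(prg.name)
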
@@ -24,7 +24,7 @@ class TestFloat4(unittest.TestCase):
     s = c.schedule()[0]
     realized_ast = s.ast
     opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
-    program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
+    program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts_to_apply)

     assert TestFloat4.count_float4(program.uops) == (2, 1)

@@ -35,7 +35,8 @@ class TestFloat4(unittest.TestCase):
     c = a + b

     s = c.schedule()[0]
-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops
     assert TestFloat4.count_float4(uops) == (4, 2)

   @unittest.skipUnless(Device.DEFAULT in {"CPU"} and AMX, "Only CPU with AMX upcasts float up to size 16")
@@ -46,7 +47,8 @@ class TestFloat4(unittest.TestCase):
     c = a + b

     s = c.schedule()[0]
-    return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops
+    return get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=shift)]).uops

     sizes = [12, 8, 16]
     shifts = [3, 2, 4]
@@ -64,7 +66,7 @@ class TestFloat4(unittest.TestCase):
     s = c.schedule()[0]
     realized_ast = s.ast
     opts_to_apply = [Opt(op=OptOps.UPCAST, axis=0, arg=4)]
-    program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
+    program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts_to_apply)

     assert TestFloat4.count_float4(program.uops) == (0, 1)

@@ -75,7 +77,8 @@ class TestFloat4(unittest.TestCase):
     c = a + b

     s = c.schedule()[0]
-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2)]).uops

     assert TestFloat4.count_float4(uops) == (0, 2)

@@ -87,7 +90,8 @@ class TestFloat4(unittest.TestCase):
     c = a + b

     s = c.schedule()[0]
-    return get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops
+    return get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=shift)]).uops

     sizes = [13, 9, 17]
     shifts = [3, 2, 4]
@@ -105,7 +109,7 @@ class TestFloat4(unittest.TestCase):
     # float4 should be emitted (the reduce axis of size 4 is the float4 axis here)

     s = c.schedule()[0]
-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UNROLL, axis=0, arg=4)]).uops

     assert TestFloat4.count_float4(uops) == (0, 0)

@@ -119,7 +123,8 @@ class TestFloat4(unittest.TestCase):
     # UPDATE: now we do this fusion

     s = c.schedule()[0]
-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops

     assert TestFloat4.count_float4(uops) in {(0,1), (1,1)}

@@ -132,7 +137,7 @@ class TestFloat4(unittest.TestCase):
     # since the top axis is not contiguous.

     s = c.schedule()[0]
-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops

     assert TestFloat4.count_float4(uops) == (0, 1)

@@ -144,7 +149,7 @@ class TestFloat4(unittest.TestCase):
     # should float4 b but not a

     s = c.schedule()[0]
-    uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops
+    uops = get_program(s.ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4)]).uops

     assert TestFloat4.count_float4(uops) == (1, 1)

@@ -13,7 +13,7 @@ class TestArange(unittest.TestCase):
     GlobalCounters.reset()
     sched = tensor.schedule()
     self.assertEqual(len(sched), 1)
-    p = get_program(sched[-1].ast)
+    p = get_program(sched[-1].ast, renderer=Device[Device.DEFAULT].renderer)
     ExecItem(CompiledRunner(p), [tensor.uop.buffer]).run()
     np.testing.assert_equal(tensor.numpy(), desired)
     return p.estimates.ops
@@ -36,7 +36,7 @@ class TestArange(unittest.TestCase):
     with Context(NOOPT=1):
       t = Tensor.ones(256, 256).contiguous().realize()
       sched = t.triu().schedule()
-      p = get_program(sched[-1].ast)
+      p = get_program(sched[-1].ast, renderer=Device[Device.DEFAULT].renderer)
       self.assertLessEqual(Estimates.from_uops(p.uops).ops, 4 * 256 * 256)

 DSET, DDIM = 2048, 32

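The arange test above shows the full loop from schedule to execution with the renderer passed explicitly. A self-contained sketch of the same pattern; the shape and values are illustrative:

    import numpy as np
    from tinygrad import Tensor, Device
    from tinygrad.engine.realize import get_program, CompiledRunner, ExecItem

    t = Tensor.arange(16)
    sched = t.schedule()                       # the test above asserts arange schedules to one kernel
    p = get_program(sched[-1].ast, renderer=Device[Device.DEFAULT].renderer)
    ExecItem(CompiledRunner(p), [t.uop.buffer]).run()   # run it against the tensor's output buffer
    np.testing.assert_equal(t.numpy(), np.arange(16))
    print(p.name, p.estimates.ops)
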
@@ -45,7 +45,7 @@ class TestLinearizer(unittest.TestCase):
     tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
     out = tst.neg().cast(dtypes.char).cast(dtypes.int).cast(dtypes.char) * 2
     ast = helper_linearizer_opt(out)
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)

   @unittest.expectedFailure
@@ -53,7 +53,7 @@ class TestLinearizer(unittest.TestCase):
     tst = Tensor.ones(16, dtype=dtypes.int).contiguous().realize()
     out = tst.neg().cast(dtypes.char).cast(dtypes.int) * 2
     ast = helper_linearizer_opt(out)
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)

   @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken on ptx")
@@ -63,7 +63,7 @@ class TestLinearizer(unittest.TestCase):
     b = Tensor.empty(16)
     out = img.conv2d(w, b)
     ast = helper_linearizer_opt(out)
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     # slice at the last loop end
     uslice = [i for i,u in enumerate(uops) if u.op == Ops.END][-1]
     # only valid test if outermost range is the reduce
@@ -84,7 +84,7 @@ class TestLinearizer(unittest.TestCase):
     a = Tensor.randn(2, ).realize()
     out = a.reshape(2, 1).expand(2, 3).sum()
     ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     assert len(ranges) == 1 # NOTE: it collapses now

@@ -92,7 +92,7 @@ class TestLinearizer(unittest.TestCase):
     a = Tensor.randn(2, ).realize()
     out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum()
     ast = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     assert len(ranges) == 1 # NOTE: it collapses now

@@ -101,7 +101,7 @@ class TestLinearizer(unittest.TestCase):
     a = Tensor([2, 2]).realize()
     out = a.reshape(2, 1).pad(((1, 1), (1, 1)), value=2).sum()
     ast = helper_linearizer_opt(out, wanna_output=[24])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     # RANGE -> ALU -> RANGE -> ALU + LOAD -> STORE
     assert any(x.op in GroupOp.ALU for x in uops[ranges[0]:ranges[1]])
@@ -114,7 +114,7 @@ class TestLinearizer(unittest.TestCase):
     b = Tensor.randn(1, 1).realize()
     out = (a + b[0]).sum() + b[0]
     ast = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     # LOAD -> RANGE -> LOAD -> STORE
     assert len([x for x in uops[:ranges[0]] if x.op is Ops.LOAD]) == 1
@@ -124,7 +124,7 @@ class TestLinearizer(unittest.TestCase):
     b = Tensor.randn(1, 1).realize()
     out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0]
     ast = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     ranges = [i for i,u in enumerate(uops) if u.op is Ops.RANGE]
     assert len(ranges) == 1 # NOTE: it collapses now

@@ -135,7 +135,7 @@ class TestLinearizer(unittest.TestCase):
     # these are of size 3 to avoid float4 coalesce
     r = a[:-1] + a[1:]

-    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
+    uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
     num_loads = len([uop for uop in uops if uop.op is Ops.LOAD])
     assert num_loads <= 4, "more load uops than needed"
     assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
@@ -147,7 +147,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = a.expand([2]) + b.expand([2])

-    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
+    uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
     num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
     assert num_ops <= 1, "more alu uops than needed"

@@ -156,7 +156,8 @@ class TestLinearizer(unittest.TestCase):
     x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
     r = Tensor.conv2d(x,w,padding=1).relu()

-    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
+    uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer,
+                       opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=0)]).uops
     accs = [u for u in uops if u.op is Ops.DEFINE_REG]
     stores = [u for u in uops if u.op is Ops.STORE]
     assert len(accs) == 0 # it's removed now
@@ -179,7 +180,7 @@ class TestLinearizer(unittest.TestCase):
     x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
     r = (x@y).relu()
     opts_to_apply = [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
-    program = get_program(r.schedule()[-1].ast, opts=opts_to_apply)
+    program = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=opts_to_apply)

     stores = [u for u in program.uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]

@@ -194,7 +195,7 @@ class TestLinearizer(unittest.TestCase):
   def test_zero_fold(self):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = Tensor.stack(a, b)
-    uops = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
+    uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
     num_ops = len([uop for uop in uops if uop.op in GroupOp.ALU])
     assert num_ops == 0, "more alu uops than needed"

@@ -204,14 +205,14 @@ class TestLinearizer(unittest.TestCase):
     if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
       a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
       realized_ast = a.schedule()[-1].ast
-      program = get_program(realized_ast, opts=[])
+      program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=[])
       local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
       assert local[0].dtype.base == acc_dtype

   def test_arg_acc_dtype(self):
     def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
       realized_ast = c.schedule()[-1].ast
-      program = get_program(realized_ast, opts=[])
+      program = get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=[])
       local = [uop for uop in program.uops if uop.op is Ops.DEFINE_REG]
       self.assertEqual(local[0].dtype.base, expected_dtype)

@@ -239,7 +240,7 @@ class TestLinearizer(unittest.TestCase):
     opt = [Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]
     ast = helper_linearizer_opt(r, [opt])
     # the uops graph is DEFINE_REG -> 4x STORE 0.0 -> RANGE -> 4x ALU -> 4x STORE -> ENDRANGE
-    uops = get_program(ast, opts=opt).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=opt).uops
     begin_range = [i for i, x in enumerate(uops) if x.op is Ops.RANGE][-1]
     end_range = [i for i, x in enumerate(uops) if x.op is Ops.END][0]
     for i,u in enumerate(uops): print(i, u.op, [uops.index(s) for s in u.src], u.arg, u.dtype)
@@ -353,7 +354,7 @@ class TestLinearizer(unittest.TestCase):
     # shrink so that the dims do not collapse
     t = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6)))
     ast = helper_linearizer_opt(t+1)
-    uops = get_program(ast, opts=[]).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=[]).uops
     idxs = dedup([uop for uop in uops if uop.op is Ops.SPECIAL])
     idxs = sorted(idxs, key=lambda uop: uop.arg)
     assert (idxs[0].arg, idxs[0].src[0].arg) == ('gidx0', 6), idxs[0]
@@ -386,13 +387,13 @@ class TestLinearizer(unittest.TestCase):
     sched_copy = sched[:]
     run_schedule(sched)
     np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
-    program = get_program(sched_copy[-1].ast, opts=())
+    program = get_program(sched_copy[-1].ast, renderer=Device[Device.DEFAULT].renderer, opts=())
     assert not any(u.op == Ops.WHERE for u in program.uops), "found where where where should be folded"

   def test_phi_simplification(self):
     def helper(t, max_ops=0):
       ast = helper_linearizer_opt(t)
-      uops = get_program(ast).uops
+      uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops
       # ignore kernel optimized IF statements for now
       if if_op:=next((u for u in uops if u.op is Ops.IF), None):
         uops = uops[:uops.index(if_op)]
@@ -425,7 +426,7 @@ class TestLinearizer(unittest.TestCase):
     out = x.matmul(y)
     with Context(TC=0):
       ast = helper_linearizer_opt(out)
-      uops = get_program(ast).uops
+      uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops
       # check that the float4 cast collapses
       store_vals = [u.src[1] for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]
       for val in store_vals:
@@ -436,7 +437,7 @@ class TestLinearizer(unittest.TestCase):
     x = Tensor.randn((4,3,6,6)).realize()
     out = x.flip((0,1)).contiguous()
     ast = helper_linearizer_opt(out)
-    store_val = [u.src[1] for u in get_program(ast).uops if u.op is Ops.STORE][0]
+    store_val = [u.src[1] for u in get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops if u.op is Ops.STORE][0]
     assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.VECTORIZE

   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@@ -449,7 +450,7 @@ class TestLinearizer(unittest.TestCase):
            Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
     ast = helper_linearizer_opt(out, opts=[opt])
     def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
-    uops = get_program(ast, opts=opt).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=opt).uops
     local_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
     global_stores = [u for u in uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_GLOBAL for x in get_recursive(u.src[0]))]
     barrier = [u for u in uops if u.op is Ops.BARRIER]
@@ -470,7 +471,7 @@ class TestLinearizer(unittest.TestCase):
     x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
     r = (x@y).relu()
     ast = helper_linearizer_opt(r)
-    uops = get_program(ast).uops
+    uops = get_program(ast, renderer=Device[Device.DEFAULT].renderer).uops
     stores = [u for u in uops if u.op is Ops.STORE and u.src[0].dtype.addrspace != AddrSpace.REG]

     # the float4 value stores directly in lds and we skip upcast
@@ -517,7 +518,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
   device = real_bufs[0].device
   wanna_output = [np.array(x).flatten() for x in wanna_output]

-  def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, opts=opts), device=device))
+  def get_prg(opts): return CompiledRunner(replace(get_program(realized_ast, renderer=Device[Device.DEFAULT].renderer, opts=opts), device=device))

   def check_opt(opts):
     prg = get_prg(opts=opts)

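Most of the linearizer tests above follow the same recipe: lower an ast with an explicit renderer plus a list of Opts, then count specific Ops in the resulting uops. A hedged, self-contained version of that recipe; the tensors and the opt are illustrative:

    from tinygrad import Tensor, Device
    from tinygrad.uop.ops import Ops
    from tinygrad.codegen.opt import Opt, OptOps
    from tinygrad.engine.realize import get_program

    r = Tensor.randn(4).realize() + Tensor.randn(4).realize()
    uops = get_program(r.schedule()[-1].ast, renderer=Device[Device.DEFAULT].renderer,
                       opts=[Opt(op=OptOps.UPCAST, axis=0, arg=0)]).uops
    print("loads:", len([u for u in uops if u.op is Ops.LOAD]))
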
@@ -3,6 +3,7 @@ import unittest
 from tinygrad.uop.ops import UOp, Ops, AxisType
 from tinygrad.dtype import dtypes
 from tinygrad.engine.realize import get_program
+from tinygrad.device import Device

 class TestLinearizerFailures(unittest.TestCase):
   def test_fail_1(self):
@@ -18,7 +19,7 @@ class TestLinearizerFailures(unittest.TestCase):
     c9 = ((((c6+(c8*UOp.const(dtypes.float, -1.0)))*(c6+(c8*UOp.const(dtypes.float, -1.0)))).reduce(c5, arg=Ops.ADD)*UOp.const(dtypes.float, 0.000390625))+UOp.const(dtypes.float, 1e-05)).sqrt().reciprocal()
     c10 = c0.index(c3).store(c9).end(c1, c2)
     ast = c10.sink()
-    get_program(ast)
+    get_program(ast, renderer=Device[Device.DEFAULT].renderer)

 if __name__ == '__main__':
-  unittest.main()
+  unittest.main()

@@ -1,6 +1,6 @@
 import numpy as np
 import unittest
-from tinygrad import Tensor
+from tinygrad import Tensor, Device
 from tinygrad.helpers import get_single_element
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
@@ -17,7 +17,7 @@ class TestOptGemm(unittest.TestCase):
     t = self.a.T @ self.b.T
     # TODO: this should be a generic test helper
     si = get_single_element(t.schedule())
-    run = CompiledRunner(get_program(si.ast, opts=opts))
+    run = CompiledRunner(get_program(si.ast, renderer=Device[Device.DEFAULT].renderer, opts=opts))
     ExecItem(run, si.bufs).run()
     test = si.bufs[0].numpy().reshape(self.res.shape)
     np.testing.assert_allclose(self.res, test, atol=1e-4)

@@ -13,7 +13,7 @@ class TestOpts(unittest.TestCase):
     s = out.schedule()
     self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
     if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM and not CPU_LVP:
-      prg = get_program(s[-1].ast)
+      prg = get_program(s[-1].ast, renderer=Device[Device.DEFAULT].renderer)
       self.assertIn('float4', prg.src)

 if __name__ == '__main__':

@@ -38,7 +38,7 @@ def create_gemm_model(model_path:str, batch_size=N, in_size=N, out_size=N, bias=

 def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
   si = out.schedule()[-1]
-  prg = get_program(si.ast, opts=opts)
+  prg = get_program(si.ast, renderer=Device[Device.DEFAULT].renderer, opts=opts)
   if replace_src is not None:
     old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0]
     prg = replace(prg, src=replace_src + "/* DSP boilerplate */" + prg.src.split("/* DSP boilerplate */")[1].replace(old_name, "fxn"))

@@ -175,13 +175,13 @@ class TestStatsOptimized(unittest.TestCase):
     self.assertEqual(p.estimates.mem, 3*N*N*4) # 3 NxN mats with floats

   def test_gemm(self):
-    p = get_program(self.ast_gemm, opts=[])
+    p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[])
     self.check_gemm(p)
     self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)

   def test_gemm_tc_unroll(self):
     try:
-      p = get_program(self.ast_gemm, opts=[Opt(OptOps.TC, 0, (-1, 0, 1)), Opt(OptOps.UNROLL, 0, 2)])
+      p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.TC, 0, (-1, 0, 1)), Opt(OptOps.UNROLL, 0, 2)])
     except KernelOptError:
       raise unittest.SkipTest("no tensor cores")
     print(p.src)
@@ -190,18 +190,19 @@ class TestStatsOptimized(unittest.TestCase):
     # this is a good lesson about why UPCASTing is a good idea

   def test_gemm_one_upcasted(self):
-    p = get_program(self.ast_gemm, opts=[Opt(OptOps.UPCAST, 0, 4)])
+    p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.UPCAST, 0, 4)])
     self.check_gemm(p)
     self.assertEqual(p.estimates.lds, N*N*N*4 + N*N*N*4//4 + 4*N*N)

   def test_gemm_upcasted(self):
-    p = get_program(self.ast_gemm, opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)])
+    p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer,
+                    opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4), Opt(OptOps.UNROLL, 0, 4)])
     self.check_gemm(p)
     self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)

   def test_gemm_upcasted_locals(self):
     try:
-      p = get_program(self.ast_gemm, opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4),
+      p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 4),
                                             Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.LOCAL, 1, 4)])
     except KernelOptError:
       raise unittest.SkipTest("no locals")
@@ -210,7 +211,7 @@ class TestStatsOptimized(unittest.TestCase):

   def test_gemm_group(self):
     try:
-      p = get_program(self.ast_gemm, opts=[Opt(OptOps.GROUP, 0, 4)])
+      p = get_program(self.ast_gemm, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.GROUP, 0, 4)])
     except KernelOptError:
       raise unittest.SkipTest("no locals")
     SZ = N*N*4
@@ -219,14 +220,14 @@ class TestStatsOptimized(unittest.TestCase):
     self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)

   def test_reduce(self):
-    p = get_program(self.ast_reduce, opts=[])
+    p = get_program(self.ast_reduce, renderer=Device[Device.DEFAULT].renderer, opts=[])
     print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
     self.assertEqual(p.estimates.ops, N*N)
     self.assertEqual(p.estimates.mem, N*N*4 + 4)

   def test_reduce_group(self):
     try:
-      p = get_program(self.ast_reduce, opts=[Opt(OptOps.GROUP, 0, 50)])
+      p = get_program(self.ast_reduce, renderer=Device[Device.DEFAULT].renderer, opts=[Opt(OptOps.GROUP, 0, 50)])
     except KernelOptError:
       raise unittest.SkipTest("no locals")
     # NOTE: these are wrong, they don't respect the if statement

@@ -14,7 +14,7 @@ from tinygrad.codegen.opt import Opt
 # **************** Program Creation ****************

 @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
-def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None) -> ProgramSpec:
+def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
   """
   Transform an AST into a ProgramSpec. May trigger BEAM search.

@@ -30,7 +30,6 @@ def get_program(ast:UOp, renderer:Renderer|None=None, opts:list[Opt]|None=None)
   if DEBUG >= 5: print(pyrender(ast))

   # linearize
-  if renderer is None: renderer = Device.default.renderer
   if opts is not None:
     assert ast.arg is None, "can't apply opts if sink has an arg"
     ast = ast.replace(arg=KernelInfo(opts_to_apply=tuple(opts)))

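With the Renderer|None default removed, get_program no longer opens a device on its own; every caller supplies the renderer, as the test updates above do with Device[Device.DEFAULT].renderer. If a call site wants the old convenience back, a small wrapper is enough; the name and placement here are hypothetical, not part of this commit:

    from tinygrad.device import Device
    from tinygrad.engine.realize import get_program

    def get_program_default(ast, opts=None):
      # resolve the renderer from the default device at the call site,
      # mirroring the Device[Device.DEFAULT].renderer pattern used in the updated tests
      return get_program(ast, renderer=Device[Device.DEFAULT].renderer, opts=opts)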