mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
rewrites for renderer and compiler (#13646)
* rewrites for renderer and compiler
* full_rewrite_to_program
* fix pre-commit
* compiler passed into get_program
* no pkl compiler
* lib on program spec
* fix spec
* fix test
* no device
* compiler_device
* nm
* fix nir
* fix
* simplest
* fix tests
* revert
This commit is contained in:
@@ -26,14 +26,14 @@ def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, d
|
||||
opts_to_apply = [Opt(OptOps.TC, axis, (tc_select, tc_opt, 1))]
|
||||
|
||||
if ensure_triggered:
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
|
||||
wmmas = len([uop for uop in program.uops if uop.op is Ops.WMMA])
|
||||
tcs = len([x for x in program.applied_opts if x.op is OptOps.TC])
|
||||
assert wmmas > 0, "tensor core not triggered"
|
||||
assert tcs == 1, "tensor core opt not included"
|
||||
else:
|
||||
try:
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
|
||||
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
|
||||
assert False, "OptOps.TC triggered, expected KernelOptError"
|
||||
except KernelOptError: pass
|
||||
|
||||
@@ -44,7 +44,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
|
||||
if dtype_in == dtypes.bfloat16: r = r.float()
|
||||
realized_ast, bufs = helper_realized_ast(r)
|
||||
opts = [Opt(op=OptOps.TC, axis=axis, arg=(tc_select, tc_opt, use_tensor_cores))]
|
||||
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts), device=Device.DEFAULT))
|
||||
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts), device=Device.DEFAULT))
|
||||
if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
|
||||
assert len([x for x in prg.p.uops[-1].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
|
||||
prg.exec(bufs)
|
||||
|
||||
Reference in New Issue
Block a user