rewrites for renderer and compiler (#13646)

* rewrites for renderer and compiler

* full_rewrite_to_program

* fix pre-commit

* compiler passed into get_program

* no pkl compiler

* lib on program spec

* fix spec

* fix test

* no device

* compiler_device

* nm

* fix nir

* fix

* simplest

* fix tests

* revert
This commit is contained in:
George Hotz
2025-12-22 18:58:43 -05:00
committed by GitHub
parent 4edaaf19e5
commit 339dadf056
8 changed files with 66 additions and 30 deletions

View File

@@ -26,14 +26,14 @@ def helper_tc_ensure_uops_and_opts_count(N: int, M:int, K:int, dtype_in:DType, d
opts_to_apply = [Opt(OptOps.TC, axis, (tc_select, tc_opt, 1))]
if ensure_triggered:
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
wmmas = len([uop for uop in program.uops if uop.op is Ops.WMMA])
tcs = len([x for x in program.applied_opts if x.op is OptOps.TC])
assert wmmas > 0, "tensor core not triggered"
assert tcs == 1, "tensor core opt not included"
else:
try:
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts_to_apply)
program = get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts_to_apply)
assert False, "OptOps.TC triggered, expected KernelOptError"
except KernelOptError: pass
@@ -44,7 +44,7 @@ def helper_tc_allclose(N:int, M:int, K:int, dtype_in:DType, dtype_out:DType, axi
if dtype_in == dtypes.bfloat16: r = r.float()
realized_ast, bufs = helper_realized_ast(r)
opts = [Opt(op=OptOps.TC, axis=axis, arg=(tc_select, tc_opt, use_tensor_cores))]
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, opts=opts), device=Device.DEFAULT))
prg = CompiledRunner(replace(get_program(realized_ast, Device[Device.DEFAULT].renderer, Device.DEFAULT, opts=opts), device=Device.DEFAULT))
if use_tensor_cores == 1: assert len([uop for uop in prg.p.uops if uop.op is Ops.WMMA]) > 0, "wmma not triggered"
assert len([x for x in prg.p.uops[-1].arg.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
prg.exec(bufs)