enable using assembly in Tensor.custom_kernel (#13895)

This commit is contained in:
qazal
2026-01-02 00:12:01 +09:00
committed by GitHub
parent c0f52c9dcb
commit 9726500de8
3 changed files with 16 additions and 13 deletions

View File

@@ -3,18 +3,21 @@
import numpy as np import numpy as np
import unittest import unittest
import subprocess, struct, math, textwrap import subprocess, struct, math, textwrap, functools
from tinygrad import Tensor, dtypes, Device, UOp from tinygrad import Tensor, dtypes, Device, UOp
from tinygrad.uop.ops import Ops from tinygrad.uop.ops import Ops, KernelInfo
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner
from extra.assembly.amd.autogen.rdna3.ins import * from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import waitcnt from extra.assembly.amd.asm import waitcnt
from test.testextra.test_cfg_viz import template from test.testextra.test_cfg_viz import template
def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
lidx = UOp.special(n_threads, "lidx0")
gidx = UOp.special(n_workgroups, "gidx0")
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]): def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize() out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
src = "\n".join(inst.disasm() for inst in [ src = "\n".join(inst.disasm() for inst in [
@@ -26,11 +29,9 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]), global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
s_endpgm() s_endpgm()
]) ])
prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK), src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0]) out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
car = CompiledRunner(prg) out.realize()
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
car([out.uop.buffer], {}, wait=True)
return out.tolist() return out.tolist()
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0] def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]

View File

@@ -166,8 +166,10 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
if ast.arg is None: ast = ast.replace(arg=KernelInfo()) if ast.arg is None: ast = ast.replace(arg=KernelInfo())
# rewrite to prg # rewrite to prg
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None) if ast.op is Ops.PROGRAM: prg = ast
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device))) else:
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render") prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")
# create the ProgramSpec # create the ProgramSpec

View File

@@ -125,7 +125,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
# NOTE: ctx is the buffers # NOTE: ctx is the buffers
si_lowerer = PatternMatcher([ si_lowerer = PatternMatcher([
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)), (UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])), (UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \ (UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \ if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \