enable using assembly in Tensor.custom_kernel (#13895)

This commit is contained in:
qazal
2026-01-02 00:12:01 +09:00
committed by GitHub
parent c0f52c9dcb
commit 9726500de8
3 changed files with 16 additions and 13 deletions

View File

@@ -3,18 +3,21 @@
import numpy as np
import unittest
import subprocess, struct, math, textwrap
import subprocess, struct, math, textwrap, functools
from tinygrad import Tensor, dtypes, Device, UOp
from tinygrad.uop.ops import Ops
from tinygrad.uop.ops import Ops, KernelInfo
from tinygrad.helpers import getenv
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import waitcnt
from test.testextra.test_cfg_viz import template
def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
lidx = UOp.special(n_threads, "lidx0")
gidx = UOp.special(n_workgroups, "gidx0")
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
src = "\n".join(inst.disasm() for inst in [
@@ -26,11 +29,9 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
s_endpgm()
])
prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0])
car = CompiledRunner(prg)
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
car([out.uop.buffer], {}, wait=True)
src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
out.realize()
return out.tolist()
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]