mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
enable using assembly in Tensor.custom_kernel (#13895)
This commit is contained in:
@@ -3,18 +3,21 @@
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import subprocess, struct, math, textwrap
|
||||
import subprocess, struct, math, textwrap, functools
|
||||
from tinygrad import Tensor, dtypes, Device, UOp
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.uop.ops import Ops, KernelInfo
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.asm import waitcnt
|
||||
from test.testextra.test_cfg_viz import template
|
||||
|
||||
def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
|
||||
lidx = UOp.special(n_threads, "lidx0")
|
||||
gidx = UOp.special(n_workgroups, "gidx0")
|
||||
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
|
||||
|
||||
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
|
||||
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
|
||||
src = "\n".join(inst.disasm() for inst in [
|
||||
@@ -26,11 +29,9 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
|
||||
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
|
||||
s_endpgm()
|
||||
])
|
||||
prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
|
||||
global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0])
|
||||
car = CompiledRunner(prg)
|
||||
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
|
||||
car([out.uop.buffer], {}, wait=True)
|
||||
src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
|
||||
out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
|
||||
out.realize()
|
||||
return out.tolist()
|
||||
|
||||
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]
|
||||
|
||||
Reference in New Issue
Block a user