enable using assembly in Tensor.custom_kernel (#13895)
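In short: Tensor.custom_kernel can now be fed a pre-built Ops.PROGRAM uop whose Ops.SOURCE src carries raw source text (here, hand-assembled RDNA3 instructions), and the codegen and scheduler lowering paths pass such a program through instead of running it through full codegen.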
@@ -3,18 +3,21 @@
 import numpy as np
 import unittest
-import subprocess, struct, math, textwrap
+import subprocess, struct, math, textwrap, functools
 from tinygrad import Tensor, dtypes, Device, UOp
-from tinygrad.uop.ops import Ops
+from tinygrad.uop.ops import Ops, KernelInfo
 from tinygrad.helpers import getenv
 from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
-from tinygrad.renderer import ProgramSpec
-from tinygrad.engine.realize import CompiledRunner
 
 from extra.assembly.amd.autogen.rdna3.ins import *
 from extra.assembly.amd.asm import waitcnt
 from test.testextra.test_cfg_viz import template
 
+def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
+  lidx = UOp.special(n_threads, "lidx0")
+  gidx = UOp.special(n_workgroups, "gidx0")
+  sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+
 def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
   out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
   src = "\n".join(inst.disasm() for inst in [
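For orientation: the new custom_src helper wraps a finished source string in an Ops.PROGRAM uop, where src[0] is the kernel sink (output plus lidx/gidx launch indices), src[1] the device, src[2] an Ops.LINEAR over the sink's uops, and src[3] an Ops.SOURCE holding the text. A minimal sketch of calling it directly (not part of the commit; assumes `out` is a realized Tensor and `src` is valid source for out.device):

    # sketch: build the PROGRAM uop by hand and inspect its structure
    prog = custom_src(out.uop, src=src, device=out.device, n_threads=4)
    assert prog.op is Ops.PROGRAM
    assert prog.src[1].op is Ops.DEVICE and prog.src[3].arg == src  # Ops.SOURCE carries the text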
@@ -26,11 +29,9 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
     global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
     s_endpgm()
   ])
-  prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
-                    global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0])
-  car = CompiledRunner(prg)
-  if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
-  car([out.uop.buffer], {}, wait=True)
+  src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
+  out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
+  out.realize()
   return out.tolist()
 
 def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]
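get_output now assembles the instruction list to text with inst.disasm(), splices it into the test template, and runs it through Tensor.custom_kernel instead of constructing a ProgramSpec/CompiledRunner pair by hand. A hedged usage sketch (the v_mov_b32 constructor and its keyword names are assumptions about the autogen RDNA3 bindings, not taken from this diff):

    # sketch: run one hand-assembled instruction and read the result back
    result = get_output([v_mov_b32(vdst=v[1], src0=42)], n_threads=1)
    assert result == [42]  # the template stores vdst (v1) to the output buffer

For reference, f16_to_bits(1.0) == 0x3c00, the IEEE binary16 bit pattern for 1.0.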
@@ -166,8 +166,10 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
   if ast.arg is None: ast = ast.replace(arg=KernelInfo())
 
   # rewrite to prg
-  full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
-  prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
+  if ast.op is Ops.PROGRAM: prg = ast
+  else:
+    full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
+    prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
   prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")
 
   # create the ProgramSpec
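The get_program change reads straightforwardly: an ast that is already an Ops.PROGRAM (as custom_src now produces) skips full_rewrite_to_sink and goes straight to the final linearize/render graph_rewrite; every other ast is lowered to a PROGRAM uop first, exactly as before.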
@@ -125,7 +125,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
 
 # NOTE: ctx is the buffers
 si_lowerer = PatternMatcher([
-  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
+  (UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
   (UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
   (UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
     if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
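The scheduler-side change is the mirror image: a UPat built from a tuple of ops matches any op in the tuple, so schedule items whose top-level uop is a custom Ops.PROGRAM now reach get_runner exactly like an ordinary Ops.SINK. A minimal sketch (assuming UPat.match keeps its usual (uop, store) signature):

    # sketch, not part of the commit: a tuple-of-ops UPat matches either op
    from tinygrad import UOp
    from tinygrad.uop.ops import Ops, UPat
    pat = UPat((Ops.SINK, Ops.PROGRAM), name="sink")
    assert pat.match(UOp(Ops.SINK), {}) and pat.match(UOp(Ops.PROGRAM), {})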