mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
enable using assembly in Tensor.custom_kernel (#13895)
This commit is contained in:
@@ -3,18 +3,21 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import unittest
|
import unittest
|
||||||
import subprocess, struct, math, textwrap
|
import subprocess, struct, math, textwrap, functools
|
||||||
from tinygrad import Tensor, dtypes, Device, UOp
|
from tinygrad import Tensor, dtypes, Device, UOp
|
||||||
from tinygrad.uop.ops import Ops
|
from tinygrad.uop.ops import Ops, KernelInfo
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
|
|
||||||
from tinygrad.renderer import ProgramSpec
|
|
||||||
from tinygrad.engine.realize import CompiledRunner
|
|
||||||
|
|
||||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||||
from extra.assembly.amd.asm import waitcnt
|
from extra.assembly.amd.asm import waitcnt
|
||||||
from test.testextra.test_cfg_viz import template
|
from test.testextra.test_cfg_viz import template
|
||||||
|
|
||||||
|
def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
|
||||||
|
lidx = UOp.special(n_threads, "lidx0")
|
||||||
|
gidx = UOp.special(n_workgroups, "gidx0")
|
||||||
|
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
|
||||||
|
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
|
||||||
|
|
||||||
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
|
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
|
||||||
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
|
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
|
||||||
src = "\n".join(inst.disasm() for inst in [
|
src = "\n".join(inst.disasm() for inst in [
|
||||||
@@ -26,11 +29,9 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
|
|||||||
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
|
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
|
||||||
s_endpgm()
|
s_endpgm()
|
||||||
])
|
])
|
||||||
prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
|
src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
|
||||||
global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0])
|
out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
|
||||||
car = CompiledRunner(prg)
|
out.realize()
|
||||||
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
|
|
||||||
car([out.uop.buffer], {}, wait=True)
|
|
||||||
return out.tolist()
|
return out.tolist()
|
||||||
|
|
||||||
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]
|
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]
|
||||||
|
|||||||
@@ -166,8 +166,10 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
|
|||||||
if ast.arg is None: ast = ast.replace(arg=KernelInfo())
|
if ast.arg is None: ast = ast.replace(arg=KernelInfo())
|
||||||
|
|
||||||
# rewrite to prg
|
# rewrite to prg
|
||||||
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
|
if ast.op is Ops.PROGRAM: prg = ast
|
||||||
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
|
else:
|
||||||
|
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
|
||||||
|
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
|
||||||
prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")
|
prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")
|
||||||
|
|
||||||
# create the ProgramSpec
|
# create the ProgramSpec
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
|
|||||||
|
|
||||||
# NOTE: ctx is the buffers
|
# NOTE: ctx is the buffers
|
||||||
si_lowerer = PatternMatcher([
|
si_lowerer = PatternMatcher([
|
||||||
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
|
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
|
||||||
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
|
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
|
||||||
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
|
||||||
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
|
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||||
|
|||||||
Reference in New Issue
Block a user