optional Estimates in KernelInfo (#14147)
* optional Estimates in KernelInfo
* custom asm test plumbing
* s_code_end
* estimates test
* vaddr arg in global_store
* kernel desc
* Ops.DEVICE name
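The user-facing change: KernelInfo gains an optional estimates field, and the estimates-counting pass returns it verbatim when present instead of deriving figures from the UOp graph. A minimal sketch of the new surface (the kernel name, element count, and figures here are illustrative, not from this diff):

    from tinygrad.uop.ops import UOp, KernelInfo
    from tinygrad.renderer import Estimates

    # a hand-written kernel over n float32 elements: one load + one store each
    n = 1024
    ki = KernelInfo(name="my_kernel", estimates=Estimates(ops=n, mem=n*4*2))
    # attach via the kernel's SINK, e.g. UOp.sink(out_buf, arg=ki); the counting
    # pass then reports these figures instead of walking the UOps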
extra/assembly/amd/test/test_custom_kernel.py (new file, +48)
@@ -0,0 +1,48 @@
+import unittest
+import functools
+from tinygrad import Tensor, Device, dtypes
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+from tinygrad.renderer import Estimates
+from tinygrad.runtime.support.compiler_amd import HIPCompiler
+
+from extra.assembly.amd.autogen.rdna3.ins import *
+from extra.assembly.amd.dsl import s, v, Inst
+
+def assemble_insts(insts:list[Inst], name:str, arch:str, kernarg_size:int=8) -> tuple[UOp, UOp]:
+  kd = {"kernarg_size":kernarg_size, "user_sgpr_kernarg_segment_ptr":1, "next_free_vgpr":8, "next_free_sgpr":8, "wavefront_size32":1}
+  disasm = "\n".join([inst.disasm() for inst in insts])
+  hsasrc = f".text\n.globl {name}\n.p2align 8\n.type fn_name,@function\n{name}:\n{disasm}\ns_code_end\n"
+  hsasrc += f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n"+"\n".join([f".amdhsa_{k} {v}" for k,v in kd.items()])+"\n.end_amdhsa_kernel"
+  binary = HIPCompiler(arch).compile(hsasrc)
+  return UOp(Ops.SOURCE, arg=disasm), UOp(Ops.BINARY, arg=binary)
+
+def custom_add_one(A:UOp, arch:str) -> UOp:
+  A = A.flatten()
+  assert dtypes.is_float(A.dtype.base), f"buffer dtype must be float32, got {A.dtype}"
+  threads = UOp.special(A.size, "lidx0")
+  insts = [
+    s_load_b64(s[0:1], s[0:1], soffset=NULL),
+    s_waitcnt(lgkmcnt=0),
+    v_lshlrev_b32_e32(v[0], 2, v[0]), # element offset
+    global_load_b32(v[1], v[0], saddr=s[0:1]),
+    s_waitcnt(vmcnt=0),
+    v_mov_b32_e32(v[2], 1.0),
+    v_add_f32_e32(v[1], v[1], v[2]),
+    global_store_b32(addr=v[0], data=v[1], saddr=s[0:1]),
+    s_endpgm(),
+  ]
+  sink = UOp.sink(A.base, threads, arg=KernelInfo(name:=f"custom_add_one_{A.size}", estimates=Estimates(ops=A.size, mem=A.size*4*2)))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=(*sink.src, sink)), *assemble_insts(insts, name, arch)), arg=())
+
+class TestCustomKernel(unittest.TestCase):
+  def test_simple(self):
+    a = Tensor.full((16, 16), 1.).contiguous().realize()
+    a = Tensor.custom_kernel(a, fxn=functools.partial(custom_add_one, arch=Device[Device.DEFAULT].arch))[0]
+    ei = a.schedule()[-1].lower()
+    self.assertEqual(ei.prg.estimates.ops, a.numel())
+    self.assertEqual(ei.prg.estimates.mem, a.nbytes()*2)
+    ei.run()
+    self.assertTrue((a.numpy() == 2.).all())
+
+if __name__ == "__main__":
+  unittest.main()
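The assertions in test_simple pin down how the estimates round-trip through scheduling: ops is one add per element, and mem counts the load plus the store of each float32, matching the Estimates(ops=A.size, mem=A.size*4*2) set above. Worked out for the 16x16 buffer (illustrative arithmetic):

    numel, nbytes = 16*16, 16*16*4           # 256 elements, 1024 bytes of float32
    assert (numel, nbytes*2) == (256, 2048)  # ei.prg.estimates.ops and .mem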
@@ -608,7 +608,8 @@ def test_matmul():
   def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
     gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
     lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
-    sink = UOp.sink(A.base, B.base, C.base, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan")))
+    sink = UOp.sink(A.base, B.base, C.base, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
+                                                                           estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
     return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=asm),
                UOp(Ops.BINARY, arg=binary)), arg=())
   c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2]
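The figures plugged into the matmul KernelInfo follow from the shape: each of the N*N outputs takes N multiply-adds at 2 FLOPs apiece, and the traffic counts the three N*N float32 buffers A, B, C once each. A quick check (the value of N here is illustrative; the test defines its own):

    N = 256
    ops = N*N*N*2   # 33_554_432: N*N outputs x N MACs x 2 FLOPs
    mem = N*N*4*3   # 786_432 bytes: three N*N buffers, 4 bytes per float32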
@@ -3,7 +3,7 @@ from typing import Callable, cast, TYPE_CHECKING
 import functools
 from dataclasses import dataclass, field
 from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
-from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops
+from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops, KernelInfo
 from tinygrad.dtype import AddrSpace, PtrDType
 from tinygrad.codegen.opt.tc import TensorCore
 from tinygrad.codegen.opt import Opt
@@ -38,6 +38,7 @@ class Estimates:
|
|||||||
elif u.op is Ops.IF:
|
elif u.op is Ops.IF:
|
||||||
dont_count = dont_count.union(u.src[0].toposort())
|
dont_count = dont_count.union(u.src[0].toposort())
|
||||||
for u in uops:
|
for u in uops:
|
||||||
|
if u.op is Ops.SINK and isinstance(u.arg, KernelInfo) and u.arg.estimates is not None: return u.arg.estimates
|
||||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||||
buf = u
|
buf = u
|
||||||
while len(buf.src): buf = buf.src[0]
|
while len(buf.src): buf = buf.src[0]
|
||||||
|
|||||||
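The added line makes explicit estimates short-circuit the counting pass: if the linearized UOps contain a SINK whose KernelInfo carries estimates, those are returned before any LOAD/STORE accounting runs. Schematically (a paraphrase, not the literal method; count_from_uop_graph is a hypothetical name for the unchanged counting path):

    def from_uops_sketch(uops):
      for u in uops:
        if u.op is Ops.SINK and isinstance(u.arg, KernelInfo) and u.arg.estimates is not None:
          return u.arg.estimates         # trust the hand-written figures
      return count_from_uop_graph(uops)  # fall back to counting LOADs, STOREs, ALUs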
@@ -10,6 +10,7 @@ from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_funct
 from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
 if TYPE_CHECKING:
   from tinygrad.device import Buffer, MultiBuffer
+  from tinygrad.renderer import Estimates

 class AxisType(Enum):
   def __repr__(self): return str(self)
@@ -828,6 +829,7 @@ class KernelInfo:
   dont_use_locals: bool = False # don't use local indexing
   applied_opts: tuple = tuple()
   opts_to_apply: tuple|None = None
+  estimates: Estimates|None = None
   @property
   def function_name(self): return to_function_name(self.name)
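Because the new field defaults to None, existing KernelInfo call sites are unaffected: a plain KernelInfo(name=...) still gets its estimates counted from the UOps, while one built with estimates=Estimates(...) is used verbatim. The Estimates import in the hunk above sits under if TYPE_CHECKING:, presumably because tinygrad.renderer already imports from tinygrad.uop.ops at runtime and a real import back would be circular; the annotation only needs the name for type checkers. Both call sites, sketched with illustrative figures:

    ki_counted = KernelInfo(name="k")                                         # estimates=None -> counted
    ki_fixed   = KernelInfo(name="k", estimates=Estimates(ops=100, mem=800))  # used verbatim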