mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove llvm requirement from amd (#14717)
* remove llvm requirement from amd * tests pass * test * sink kernarg_size * move stuff * amd_asm_matmul to new style * default type * fix tests, simpler * cu mode is faster and simpler * darken
This commit is contained in:
@@ -9,10 +9,10 @@
|
||||
# Accumulators: 128 vgprs (v[2-129])
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from tinygrad import Tensor, Device, Context, GlobalCounters
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.helpers import getenv, colored
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.engine.realize import Estimates
|
||||
from tinygrad.renderer.amd.dsl import s, v, VCC_LO, NULL
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
@@ -51,14 +51,14 @@ V_B_TILE_REGS = [132, 136, 140, 144, 148, 152, 156, 160] # B tile: banks 0,0,0,
|
||||
# Named register assignments (SGPRs)
|
||||
# =============================================================================
|
||||
S_OUT_PTR = (0, 1) # output C matrix base pointer
|
||||
S_TILE_X = 2 # workgroup_x << 7
|
||||
S_TILE_Y = 3 # workgroup_y << 7
|
||||
S_WORKGROUP_X = 2 # workgroup_id_x (system SGPR, follows user SGPRs)
|
||||
S_WORKGROUP_Y = 3 # workgroup_id_y (system SGPR)
|
||||
S_DIM_N = 4 # matrix dimension N
|
||||
S_LOOP_BOUND = 7 # K-8 (loop termination bound)
|
||||
S_LOOP_CTR = 12 # loop counter (increments by 8)
|
||||
S_PREFETCH_FLAG = 13 # prefetch condition flag / row stride in epilogue
|
||||
S_WORKGROUP_X = 14 # workgroup_id_x
|
||||
S_WORKGROUP_Y = 15 # workgroup_id_y
|
||||
S_TILE_X = 14 # workgroup_x << 7
|
||||
S_TILE_Y = 15 # workgroup_y << 7
|
||||
# Kernarg load destinations
|
||||
S_KERNARG_A = (20, 21) # A pointer from kernarg
|
||||
S_KERNARG_B = (22, 23) # B pointer from kernarg
|
||||
@@ -183,48 +183,14 @@ class Kernel:
|
||||
waitcnt = (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
|
||||
self.emit(s_waitcnt(simm16=waitcnt))
|
||||
|
||||
def to_asm(self):
|
||||
# Patch branch offsets: simm16 = (target_pos - branch_end_pos) / 4
|
||||
def finalize(self):
|
||||
"""Patch branch offsets and return the finalized instruction list."""
|
||||
for inst in self.instructions:
|
||||
if inst._target is None: continue
|
||||
offset_dwords = (self.labels[inst._target] - inst._pos - inst.size()) // 4
|
||||
if not -32768 <= offset_dwords <= 32767: raise ValueError(f"branch to '{inst._target}' offset {offset_dwords} exceeds simm16 range")
|
||||
inst.simm16 = offset_dwords
|
||||
|
||||
# TODO: replace this with direct ELF
|
||||
from test.amd.disasm import disasm
|
||||
body = ['\t' + disasm(inst) for inst in self.instructions]
|
||||
|
||||
# limit wave occupancy by using more LDS
|
||||
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
|
||||
|
||||
# HSA kernel descriptor attributes (zeros included for compatibility)
|
||||
hsa = [
|
||||
('group_segment_fixed_size', lds_size), ('private_segment_fixed_size', 0), ('kernarg_size', 36),
|
||||
('user_sgpr_count', 14), ('user_sgpr_dispatch_ptr', 0), ('user_sgpr_queue_ptr', 0),
|
||||
('user_sgpr_kernarg_segment_ptr', 1), ('user_sgpr_dispatch_id', 0), ('user_sgpr_private_segment_size', 0),
|
||||
('wavefront_size32', 1), ('uses_dynamic_stack', 0), ('enable_private_segment', 0),
|
||||
('system_sgpr_workgroup_id_x', 1), ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 0),
|
||||
('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 179),
|
||||
('next_free_sgpr', 16), ('float_round_mode_32', 0), ('float_round_mode_16_64', 0),
|
||||
('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3), ('dx10_clamp', 1), ('ieee_mode', 1),
|
||||
('fp16_overflow', 0), ('workgroup_processor_mode', 0), ('memory_ordered', 1), ('forward_progress', 0),
|
||||
('shared_vgpr_count', 0)]
|
||||
|
||||
return '\n'.join([
|
||||
'\t.text', f'\t.amdgcn_target "amdgcn-amd-amdhsa--{self.arch}"',
|
||||
'\t.protected\tkernel', '\t.globl\tkernel', '\t.p2align\t8', '\t.type\tkernel,@function', 'kernel:',
|
||||
*body,
|
||||
'\t.section\t.rodata,"a",@progbits', '\t.p2align\t6, 0x0', '\t.amdhsa_kernel kernel',
|
||||
*[f'\t\t.amdhsa_{k} {v}' for k, v in hsa],
|
||||
'\t.end_amdhsa_kernel', '\t.text', '.Lfunc_end0:', '\t.size\tkernel, .Lfunc_end0-kernel',
|
||||
'\t.amdgpu_metadata', '---', 'amdhsa.kernels:', ' - .args:',
|
||||
*[f' - .address_space: global\n .offset: {i*8}\n .size: 8\n .value_kind: global_buffer' for i in range(3)],
|
||||
f' .group_segment_fixed_size: {lds_size}', ' .kernarg_segment_align: 8',
|
||||
' .kernarg_segment_size: 24', ' .max_flat_workgroup_size: 128', ' .name: kernel',
|
||||
' .private_segment_fixed_size: 0', ' .sgpr_count: 60', ' .symbol: kernel.kd',
|
||||
' .vgpr_count: 179', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
|
||||
'amdhsa.version:', ' - 1', ' - 2', '...', '\t.end_amdgpu_metadata'])
|
||||
return self.instructions
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -460,7 +426,7 @@ def build_kernel(arch='gfx1100'):
|
||||
k.emit(s_sendmsg(simm16=3)) # DEALLOC_VGPRS
|
||||
k.emit(s_endpgm())
|
||||
|
||||
return k.to_asm()
|
||||
return k.finalize()
|
||||
|
||||
# =============================================================================
|
||||
# Test harness
|
||||
@@ -474,16 +440,7 @@ def test_matmul():
|
||||
dev = Device[Device.DEFAULT]
|
||||
print(f"Device arch: {dev.renderer.arch}")
|
||||
|
||||
if getenv("STOCK", 0):
|
||||
# Load the stock kernel from amd_seb/kernel8_batched_gmem.s
|
||||
stock_path = Path(__file__).parent / "amd_seb" / "kernel8_batched_gmem.s"
|
||||
asm = stock_path.read_text()
|
||||
print(f"Loaded stock kernel from {stock_path}")
|
||||
else:
|
||||
asm = build_kernel(dev.renderer.arch)
|
||||
|
||||
binary = dev.compiler.compile(asm)
|
||||
print(f"Compiled! Binary size: {len(binary)} bytes")
|
||||
insts = build_kernel(dev.renderer.arch)
|
||||
|
||||
rng = np.random.default_rng(42)
|
||||
a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
|
||||
@@ -498,10 +455,10 @@ def test_matmul():
|
||||
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
|
||||
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
|
||||
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
|
||||
sink = UOp.sink(A.base, B.base, C.base, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
|
||||
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=asm),
|
||||
UOp(Ops.BINARY, arg=binary)))
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
|
||||
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2]
|
||||
ei = c.schedule()[0].lower()
|
||||
|
||||
@@ -542,6 +499,6 @@ def run_sqtt():
|
||||
print(f"Wrote {len(output)} bytes to /tmp/sqtt_trace.txt")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if getenv("ASM", 0): print(build_kernel(Device[Device.DEFAULT].arch))
|
||||
if getenv("ASM", 0): print("\n".join(str(inst) for inst in build_kernel(Device[Device.DEFAULT].renderer.arch)))
|
||||
elif getenv("SQTT", 0): run_sqtt()
|
||||
else: test_matmul()
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import os, pathlib
|
||||
import os
|
||||
|
||||
# TODO: there is a timing bug without this
|
||||
os.environ["AMD_AQL"] = "1"
|
||||
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.renderer.amd.dsl import Reg, Inst, s, v
|
||||
|
||||
NUM_WORKGROUPS = 96
|
||||
@@ -13,17 +14,22 @@ NUM_WAVES = 2
|
||||
FLOPS_PER_MATMUL = 16*16*16*2
|
||||
INTERNAL_LOOP = 1_000_00
|
||||
INSTRUCTIONS_PER_LOOP = 200
|
||||
DIRECTIVE = ".amdhsa_wavefront_size32 1"
|
||||
|
||||
assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text()
|
||||
|
||||
def repeat(insts:list[Inst], n:int, counter_sreg:Reg) -> bytes:
|
||||
preamble = s_mov_b32(counter_sreg, n).to_bytes()
|
||||
def repeat(insts:list[Inst], n:int, counter_sreg:Reg) -> list[Inst]:
|
||||
insts_bytes = b"".join([inst.to_bytes() for inst in insts])
|
||||
sub_inst, cmp_inst = s_sub_u32(counter_sreg, counter_sreg, 1), s_cmp_lg_i32(counter_sreg, 0)
|
||||
loop_sz = len(insts_bytes) + sub_inst.size() + cmp_inst.size()
|
||||
branch_inst = s_cbranch_scc1(simm16=-((loop_sz // 4) + 1) & 0xFFFF)
|
||||
return preamble + insts_bytes + sub_inst.to_bytes() + cmp_inst.to_bytes() + branch_inst.to_bytes() + s_endpgm().to_bytes()
|
||||
return [s_mov_b32(counter_sreg, n)] + insts + [sub_inst, cmp_inst, branch_inst, s_endpgm()]
|
||||
|
||||
def make_kernel(insts:list[Inst]):
|
||||
def fxn(A:UOp) -> UOp:
|
||||
threads = UOp.special(WAVE_SIZE * NUM_WAVES, "lidx0")
|
||||
gidx = UOp.special(NUM_WORKGROUPS, "gidx0")
|
||||
FLOPs = FLOPS_PER_MATMUL * NUM_WAVES * NUM_WORKGROUPS * INTERNAL_LOOP * INSTRUCTIONS_PER_LOOP
|
||||
sink = UOp.sink(A.base, threads, gidx, arg=KernelInfo("mmapeak", estimates=Estimates(ops=FLOPs, mem=0)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
return fxn
|
||||
|
||||
def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs):
|
||||
if accum:
|
||||
@@ -32,16 +38,12 @@ def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs)
|
||||
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1)
|
||||
else:
|
||||
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[3]:vgprIndices[4]], v[vgprIndices[5]])
|
||||
vgprs:set = set()
|
||||
for n,_ in inst._fields:
|
||||
if isinstance(val:=getattr(inst, n), Reg) and val.offset >= v.offset: vgprs |= {val.offset+i for i in range(val.sz)}
|
||||
inst_bytes = repeat([inst for _ in range(INSTRUCTIONS_PER_LOOP)], n=INTERNAL_LOOP, counter_sreg=s[1])
|
||||
inst_hex = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16)) + "\n"
|
||||
src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", inst_hex).replace("VGPR_COUNT", str(len(vgprs)))
|
||||
src = src.replace("DIRECTIVE", DIRECTIVE)
|
||||
lib = COMPILER.compile(src)
|
||||
fxn = DEV.runtime("matmul", lib)
|
||||
elapsed = min([fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True) for _ in range(2)])
|
||||
insts = repeat([inst for _ in range(INSTRUCTIONS_PER_LOOP)], n=INTERNAL_LOOP, counter_sreg=s[1])
|
||||
fxn = make_kernel(insts)
|
||||
dummy = Tensor.zeros(1).contiguous().realize()
|
||||
out = Tensor.custom_kernel(dummy, fxn=fxn)[0]
|
||||
ei = out.schedule()[-1].lower()
|
||||
elapsed = min([ei.run(wait=True) for _ in range(2)])
|
||||
FLOPs = FLOPS_PER_MATMUL * NUM_WAVES * NUM_WORKGROUPS * INTERNAL_LOOP * INSTRUCTIONS_PER_LOOP
|
||||
print(f"{inst.op_name.lower():<29} : {FLOPs/elapsed/10**12:.2f} T(FL)OPS")
|
||||
|
||||
@@ -49,7 +51,6 @@ if __name__=="__main__":
|
||||
DEV = Device[Device.DEFAULT]
|
||||
arch = DEV.renderer.arch
|
||||
|
||||
COMPILER = HIPCompiler(arch)
|
||||
if arch in {'gfx1100', 'gfx1103', 'gfx1151'}:
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
if arch == 'gfx1103': NUM_WORKGROUPS = 8
|
||||
@@ -91,7 +92,6 @@ if __name__=="__main__":
|
||||
launchBenchmark(v_swmmac_i32_16x16x64_iu4, (7,8,9,10,13,14), False)
|
||||
elif arch == 'gfx950':
|
||||
from tinygrad.runtime.autogen.amd.cdna.ins import *
|
||||
DIRECTIVE = ".amdhsa_accum_offset 4"
|
||||
NUM_WORKGROUPS = 256
|
||||
WAVE_SIZE = 64
|
||||
NUM_WAVES = 4
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
.text
|
||||
.globl matmul
|
||||
.p2align 8
|
||||
.type matmul,@function
|
||||
matmul:
|
||||
INSTRUCTION
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel matmul
|
||||
.amdhsa_next_free_vgpr VGPR_COUNT
|
||||
.amdhsa_next_free_sgpr 3
|
||||
DIRECTIVE
|
||||
.end_amdhsa_kernel
|
||||
|
||||
.amdgpu_metadata
|
||||
---
|
||||
amdhsa.version:
|
||||
- 1
|
||||
- 0
|
||||
amdhsa.kernels:
|
||||
- .name: matmul
|
||||
.symbol: matmul.kd
|
||||
.kernarg_segment_size: 0
|
||||
.group_segment_fixed_size: 0
|
||||
.private_segment_fixed_size: 0
|
||||
.kernarg_segment_align: 4
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 8
|
||||
.vgpr_count: 32
|
||||
.max_flat_workgroup_size: 1024
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
@@ -1,23 +1,12 @@
|
||||
import unittest
|
||||
import functools
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
from tinygrad.renderer.amd.dsl import s, v, Inst
|
||||
from test.amd.disasm import disasm as disasm_inst
|
||||
from tinygrad.renderer.amd.dsl import s, v
|
||||
|
||||
def assemble_insts(insts:list[Inst], name:str, arch:str, kernarg_size:int=8) -> tuple[UOp, UOp]:
|
||||
kd = {"kernarg_size":kernarg_size, "user_sgpr_kernarg_segment_ptr":1, "next_free_vgpr":8, "next_free_sgpr":8, "wavefront_size32":1}
|
||||
disasm = "\n".join([disasm_inst(inst) for inst in insts])
|
||||
hsasrc = f".text\n.globl {name}\n.p2align 8\n.type fn_name,@function\n{name}:\n{disasm}\ns_code_end\n"
|
||||
hsasrc += f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n"+"\n".join([f".amdhsa_{k} {v}" for k,v in kd.items()])+"\n.end_amdhsa_kernel"
|
||||
binary = HIPCompiler(arch).compile(hsasrc)
|
||||
return UOp(Ops.SOURCE, arg=disasm), UOp(Ops.BINARY, arg=binary)
|
||||
|
||||
def custom_add_one(A:UOp, arch:str) -> UOp:
|
||||
def custom_add_one(A:UOp) -> UOp:
|
||||
A = A.flatten()
|
||||
assert dtypes.is_float(A.dtype.base), f"buffer dtype must be float32, got {A.dtype}"
|
||||
threads = UOp.special(A.size, "lidx0")
|
||||
@@ -32,10 +21,10 @@ def custom_add_one(A:UOp, arch:str) -> UOp:
|
||||
global_store_b32(addr=v[0], data=v[1], saddr=s[0:1]),
|
||||
s_endpgm(),
|
||||
]
|
||||
sink = UOp.sink(A.base, threads, arg=KernelInfo(name:=f"custom_add_one_{A.size}", estimates=Estimates(ops=A.size, mem=A.size*4*2)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=(*sink.src, sink)), *assemble_insts(insts, name, arch)))
|
||||
sink = UOp.sink(A.base, threads, arg=KernelInfo(f"custom_add_one_{A.size}", estimates=Estimates(ops=A.size, mem=A.size*4*2)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
||||
def custom_add_var(A:UOp, B:UOp, arch:str) -> UOp:
|
||||
def custom_add_var(A:UOp, B:UOp) -> UOp:
|
||||
A,B = A.flatten(), B.flatten()
|
||||
assert A.dtype.base == dtypes.uint32, f"buffer dtype must be uint32, got {A.dtype}"
|
||||
threads = UOp.special(A.size, "lidx0")
|
||||
@@ -51,15 +40,14 @@ def custom_add_var(A:UOp, B:UOp, arch:str) -> UOp:
|
||||
global_store_b32(addr=v[0], data=v[1], saddr=s[4:5]),
|
||||
s_endpgm(),
|
||||
]
|
||||
sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(name:=f"custom_add_one_{A.size}"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
||||
*assemble_insts(insts, name, arch, kernarg_size=16)))
|
||||
sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(f"custom_add_var_{A.size}"))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device")
|
||||
class TestCustomKernel(unittest.TestCase):
|
||||
def test_simple(self):
|
||||
a = Tensor.full((16, 16), 1.).contiguous().realize()
|
||||
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_add_one, arch=Device[Device.DEFAULT].renderer.arch))[0] # type: ignore[attr-defined]
|
||||
a = Tensor.custom_kernel(a, fxn=custom_add_one)[0]
|
||||
ei = a.schedule()[-1].lower()
|
||||
self.assertEqual(ei.prg.estimates.ops, a.numel())
|
||||
self.assertEqual(ei.prg.estimates.mem, a.nbytes()*2)
|
||||
@@ -69,7 +57,7 @@ class TestCustomKernel(unittest.TestCase):
|
||||
def test_variable(self):
|
||||
b = Tensor.full((16, 16), 1, dtype=dtypes.uint32).contiguous().realize()
|
||||
a = Tensor.zeros_like(b).contiguous().realize()
|
||||
a = Tensor.custom_kernel(a, b, fxn=functools.partial(custom_add_var, arch=Device[Device.DEFAULT].renderer.arch))[0] # type: ignore[attr-defined]
|
||||
a = Tensor.custom_kernel(a, b, fxn=custom_add_var)[0]
|
||||
ei = a.schedule()[-1].lower()
|
||||
for i in range(4):
|
||||
ei.run({"var":i})
|
||||
|
||||
@@ -2,12 +2,9 @@
|
||||
# allow define from star imports
|
||||
|
||||
import unittest
|
||||
import functools
|
||||
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.device import Compiler
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from tinygrad.viz.serve import amdgpu_cfg
|
||||
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import *
|
||||
@@ -16,24 +13,17 @@ from tinygrad.renderer.amd.dsl import s
|
||||
# TODO: this belongs to the dsl infrastructure
|
||||
from extra.gemm.amd_asm_matmul import Kernel
|
||||
|
||||
# TODO: shouldn't need compiler once we can output ELF
|
||||
# outputs a text disassembly for humans and a machine readable binary
|
||||
def assemble(name:str, k:Kernel, compiler:Compiler) -> tuple[str, bytes]:
|
||||
src = k.to_asm()
|
||||
return (src, compiler.compile(src))
|
||||
|
||||
def asm_kernel(out:UOp, k:Kernel, name:str, device:str, compiler:Compiler, n_threads:int=1, n_workgroups:int=1) -> UOp:
|
||||
lidx = UOp.special(n_threads, "lidx0")
|
||||
gidx = UOp.special(n_workgroups, "gidx0")
|
||||
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name=name))
|
||||
src, lib = assemble(name, k, compiler)
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)),
|
||||
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
|
||||
|
||||
def run_asm(name:str, k:Kernel) -> None:
|
||||
fxn = functools.partial(asm_kernel, k=k, name=name, device=Device.DEFAULT, compiler=HIPCompiler(Device[Device.DEFAULT].renderer.arch))
|
||||
def run_asm(name:str, k:Kernel):
|
||||
insts = k.finalize()
|
||||
def fxn(out:UOp) -> UOp:
|
||||
lidx = UOp.special(1, "lidx0")
|
||||
gidx = UOp.special(1, "gidx0")
|
||||
sink = UOp.sink(out.base, lidx, gidx, arg=KernelInfo(name=name))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
|
||||
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
|
||||
out.realize()
|
||||
ei = out.schedule()[-1].lower()
|
||||
ei.run()
|
||||
return ei
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "AMD", "only on AMD")
|
||||
class TestCfg(unittest.TestCase):
|
||||
@@ -67,9 +57,8 @@ class TestCfg(unittest.TestCase):
|
||||
k.label("end")
|
||||
k.emit(s_endpgm())
|
||||
k.emit(s_code_end())
|
||||
run_asm("diamond", k)
|
||||
_, lib = assemble("diamond", k, HIPCompiler(Device[Device.DEFAULT].arch))
|
||||
cfg = amdgpu_cfg(lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"]
|
||||
ei = run_asm("diamond", k)
|
||||
cfg = amdgpu_cfg(ei.prg.p.lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"]
|
||||
self.assertEqual(len(cfg["blocks"]), 5)
|
||||
edge_count = sum(len(v) for v in cfg["paths"].values())
|
||||
self.assertEqual(edge_count, 5)
|
||||
|
||||
@@ -21,6 +21,7 @@ from tinygrad.codegen.opt.postrange import apply_opts, pm_make_images
|
||||
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
|
||||
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar
|
||||
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
|
||||
from tinygrad.renderer.amd.elf import do_assemble_amd
|
||||
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
if ren is None: ren = Renderer()
|
||||
@@ -144,6 +145,7 @@ def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None:
|
||||
|
||||
pm_to_program = PatternMatcher([
|
||||
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize),
|
||||
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, src=UPat(Ops.INS), name="lin")), name="prg"), do_assemble_amd),
|
||||
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
|
||||
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE, name="source")), name="prg"), do_compile),
|
||||
])
|
||||
|
||||
@@ -38,7 +38,6 @@ class Estimates:
|
||||
elif u.op is Ops.IF:
|
||||
dont_count = dont_count.union(u.src[0].toposort())
|
||||
for u in uops:
|
||||
if u.op is Ops.SINK and isinstance(u.arg, KernelInfo) and u.arg.estimates is not None: return u.arg.estimates
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
buf = u
|
||||
while len(buf.src): buf = buf.src[0]
|
||||
@@ -82,6 +81,7 @@ class ProgramSpec:
|
||||
|
||||
@functools.cached_property
|
||||
def estimates(self) -> Estimates:
|
||||
if self.ast.op is Ops.SINK and isinstance(self.ast.arg, KernelInfo) and self.ast.arg.estimates is not None: return self.ast.arg.estimates
|
||||
return Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True)
|
||||
|
||||
@functools.cached_property
|
||||
@@ -118,7 +118,7 @@ class ProgramSpec:
|
||||
ins: list[int] = []
|
||||
global_size: list[int] = [1, 1, 1]
|
||||
local_size: list[int]|None = [1, 1, 1]
|
||||
for u in uops:
|
||||
for u in sink.toposort():
|
||||
if u.op is Ops.DEFINE_VAR: _vars.append(u)
|
||||
if u.op is Ops.PARAM: _globals.append(u.arg)
|
||||
if u.op in (Ops.STORE, Ops.LOAD):
|
||||
|
||||
@@ -99,7 +99,7 @@ bits = _Bits()
|
||||
|
||||
class BitField:
|
||||
name: str | None
|
||||
def __init__(self, hi: int, lo: int, default: int = 0):
|
||||
def __init__(self, hi: int, lo: int, default = 0):
|
||||
self.hi, self.lo, self.default, self.name, self.mask = hi, lo, default, None, (1 << (hi - lo + 1)) - 1
|
||||
def __set_name__(self, owner, name: str): self.name = name
|
||||
def __eq__(self, other) -> 'FixedBitField': # type: ignore[override]
|
||||
|
||||
127
tinygrad/renderer/amd/elf.py
Normal file
127
tinygrad/renderer/amd/elf.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# minimal amdgpu elf packer
|
||||
import ctypes
|
||||
from tinygrad.helpers import ceildiv, round_up
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.runtime.autogen import amdgpu_kd, hsa, libc
|
||||
from tinygrad.renderer.amd.dsl import Reg, FixedBitField
|
||||
|
||||
# instructions used for padding
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import s_code_end # same encoding as RDNA4
|
||||
from tinygrad.runtime.autogen.amd.cdna.ins import s_nop as s_nop_cdna
|
||||
|
||||
def put(dst:bytearray, off:int, data:bytes) -> None:
|
||||
end = off + len(data)
|
||||
if end > len(dst): raise ValueError("write past end of buffer")
|
||||
dst[off:end] = data
|
||||
|
||||
def create_elf(prg:bytes, kd:dict, arch:str) -> bytes:
|
||||
is_cdna, is_rdna4 = arch == "cdna", arch == "rdna4"
|
||||
padding_inst = (s_nop_cdna(0) if is_cdna else s_code_end()).to_bytes()
|
||||
text = prg + padding_inst * ((hsa.AMD_ISA_ALIGN_BYTES - len(prg) % hsa.AMD_ISA_ALIGN_BYTES) % hsa.AMD_ISA_ALIGN_BYTES)
|
||||
text_offset = round_up(ctypes.sizeof(libc.Elf64_Ehdr), hsa.AMD_ISA_ALIGN_BYTES)
|
||||
rodata_offset = text_offset + len(text)
|
||||
|
||||
# ** pack rodata object
|
||||
desc = amdgpu_kd.llvm_amdhsa_kernel_descriptor_t()
|
||||
desc.group_segment_fixed_size = kd.get("group_segment_fixed_size", 0)
|
||||
desc.private_segment_fixed_size = kd.get("private_segment_fixed_size", 0)
|
||||
desc.kernarg_size = kd.get("kernarg_size", 0)
|
||||
desc.kernel_code_entry_byte_offset = text_offset-rodata_offset
|
||||
# rsrc1
|
||||
vgpr_granule = max(0, (kd["next_free_vgpr"] + 7) // 8 - 1)
|
||||
# CDNA: add 6 for VCC(2) + FLAT_SCRATCH(2) + XNACK_MASK(2)
|
||||
# next_free_sgpr is unused in RDNA
|
||||
# NOTE: CU mode is the default, it seems faster and simpler
|
||||
sgpr_granule = max(0, ceildiv(kd["next_free_sgpr"] + 6, 8) - 1) if is_cdna else 0
|
||||
desc.compute_pgm_rsrc1 = (vgpr_granule << amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WORKITEM_VGPR_COUNT_SHIFT |
|
||||
sgpr_granule << amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WAVEFRONT_SGPR_COUNT_SHIFT |
|
||||
kd.get("float_round_mode_32", 0) << amdgpu_kd.COMPUTE_PGM_RSRC1_FLOAT_ROUND_MODE_32_SHIFT |
|
||||
kd.get("float_round_mode_16_64", 0) << amdgpu_kd.COMPUTE_PGM_RSRC1_FLOAT_ROUND_MODE_16_64_SHIFT |
|
||||
kd.get("float_denorm_mode_32", 0) << amdgpu_kd.COMPUTE_PGM_RSRC1_FLOAT_DENORM_MODE_32_SHIFT |
|
||||
kd.get("float_denorm_mode_16_64", 3) << amdgpu_kd.COMPUTE_PGM_RSRC1_FLOAT_DENORM_MODE_16_64_SHIFT |
|
||||
kd.get("dx10_clamp", 0 if is_rdna4 else 1) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX6_GFX11_ENABLE_DX10_CLAMP_SHIFT |
|
||||
kd.get("ieee_mode", 0 if is_rdna4 else 1) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX6_GFX11_ENABLE_IEEE_MODE_SHIFT |
|
||||
kd.get("fp16_overflow", 0) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX9_PLUS_FP16_OVFL_SHIFT |
|
||||
(0 if is_cdna else kd.get("workgroup_processor_mode", 0)) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX10_PLUS_WGP_MODE_SHIFT |
|
||||
(0 if is_cdna else kd.get("memory_ordered", 1)) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX10_PLUS_MEM_ORDERED_SHIFT |
|
||||
(0 if is_cdna else kd.get("forward_progress", 0)) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX10_PLUS_FWD_PROGRESS_SHIFT)
|
||||
# rsrc2
|
||||
desc.compute_pgm_rsrc2 = (kd.get("enable_private_segment", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_PRIVATE_SEGMENT_SHIFT |
|
||||
kd.get("user_sgpr_count", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_USER_SGPR_COUNT_SHIFT |
|
||||
kd.get("system_sgpr_workgroup_id_x", 1) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_ID_X_SHIFT |
|
||||
kd.get("system_sgpr_workgroup_id_y", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_ID_Y_SHIFT |
|
||||
kd.get("system_sgpr_workgroup_id_z", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_ID_Z_SHIFT |
|
||||
kd.get("system_sgpr_workgroup_info", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_INFO_SHIFT |
|
||||
kd.get("system_vgpr_workitem_id", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_VGPR_WORKITEM_ID_SHIFT)
|
||||
# rsrc3
|
||||
if is_cdna:
|
||||
amdhsa_accum_offset = ((kd.get("accum_offset", 4) // 4) - 1) & amdgpu_kd.COMPUTE_PGM_RSRC3_GFX90A_ACCUM_OFFSET
|
||||
desc.compute_pgm_rsrc3 = amdhsa_accum_offset << amdgpu_kd.COMPUTE_PGM_RSRC3_GFX90A_ACCUM_OFFSET_SHIFT
|
||||
else:
|
||||
desc.compute_pgm_rsrc3 = kd.get("shared_vgpr_count", 0) << amdgpu_kd.COMPUTE_PGM_RSRC3_GFX10_GFX11_SHARED_VGPR_COUNT_SHIFT
|
||||
# kernel code properties
|
||||
desc.kernel_code_properties = (kd.get("user_sgpr_dispatch_ptr", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_DISPATCH_PTR_SHIFT |
|
||||
kd.get("user_sgpr_queue_ptr", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_QUEUE_PTR_SHIFT |
|
||||
kd.get("user_sgpr_kernarg_segment_ptr", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_KERNARG_SEGMENT_PTR_SHIFT |
|
||||
kd.get("user_sgpr_dispatch_id", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_DISPATCH_ID_SHIFT |
|
||||
kd.get("user_sgpr_private_segment_size",0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_PRIVATE_SEGMENT_SIZE_SHIFT |
|
||||
kd.get("wavefront_size32", 0 if is_cdna else 1) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_WAVEFRONT_SIZE32_SHIFT |
|
||||
kd.get("uses_dynamic_stack", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_USES_DYNAMIC_STACK_SHIFT)
|
||||
rodata = bytes(desc)
|
||||
|
||||
# ** pack elf sections
|
||||
sh_names:list[int] = []
|
||||
strtab = bytearray(b"\x00")
|
||||
for name in [".text", ".rodata", ".strtab"]:
|
||||
sh_names.append(len(strtab))
|
||||
strtab += name.encode("ascii") + b"\x00"
|
||||
|
||||
rodata_offset = round_up(text_offset+(text_size:=len(text)), hsa.AMD_KERNEL_CODE_ALIGN_BYTES)
|
||||
strtab_offset = rodata_offset+(rodata_size:=len(rodata))
|
||||
shdr_offset = strtab_offset+(strtab_size:=len(strtab))
|
||||
|
||||
sections = [(libc.SHT_PROGBITS, libc.SHF_ALLOC | libc.SHF_EXECINSTR, text_offset, text_offset, text_size),
|
||||
(libc.SHT_PROGBITS, libc.SHF_ALLOC, rodata_offset, rodata_offset, rodata_size),
|
||||
(libc.SHT_STRTAB, 0, 0, strtab_offset, strtab_size)]
|
||||
shdrs = (libc.Elf64_Shdr * len(sections))()
|
||||
for i,s in enumerate(sections): shdrs[i] = libc.Elf64_Shdr(sh_names[i], *s)
|
||||
|
||||
ehdr = libc.Elf64_Ehdr()
|
||||
ehdr.e_shoff, ehdr.e_shnum, ehdr.e_shstrndx = shdr_offset, len(sections), 2
|
||||
|
||||
elf = bytearray(shdr_offset + ctypes.sizeof(shdrs))
|
||||
put(elf, 0, bytes(ehdr))
|
||||
put(elf, text_offset, text)
|
||||
put(elf, rodata_offset, rodata)
|
||||
put(elf, strtab_offset, strtab)
|
||||
put(elf, shdr_offset, bytes(shdrs))
|
||||
return bytes(elf)
|
||||
|
||||
_arch_map = {"gfx9": "cdna", "gfx10": "rdna3", "gfx11": "rdna3", "gfx12": "rdna4"}
|
||||
def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp:
|
||||
insts = [u.arg for u in lin.src]
|
||||
# scan for max vgpr/sgpr
|
||||
max_vgpr, max_sgpr = 0, 0
|
||||
for inst in insts:
|
||||
for name, field in inst._fields:
|
||||
if isinstance(field, FixedBitField): continue
|
||||
val = getattr(inst, name)
|
||||
if not isinstance(val, Reg): continue
|
||||
if 256 <= val.offset < 512: max_vgpr = max(max_vgpr, (val.offset - 256) + val.sz)
|
||||
elif val.offset < 106: max_sgpr = max(max_sgpr, val.offset + val.sz)
|
||||
# scan sink for metadata
|
||||
sink, n_bufs, n_vars, lds_size, gids = prg.src[0], 0, 0, 0, set()
|
||||
for u in sink.toposort():
|
||||
if u.op is Ops.PARAM: n_bufs += 1
|
||||
elif u.op is Ops.DEFINE_VAR: n_vars += 1
|
||||
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
|
||||
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))
|
||||
src = "\n".join(str(inst) for inst in insts)
|
||||
code_bytes = b"".join(inst.to_bytes() for inst in insts)
|
||||
arch = next(v for k, v in _arch_map.items() if ctx.arch.startswith(k))
|
||||
kd = {"kernarg_size":n_bufs*8+n_vars*4, "group_segment_fixed_size":lds_size,
|
||||
"user_sgpr_kernarg_segment_ptr":1, "user_sgpr_count":2,
|
||||
"system_sgpr_workgroup_id_x":int(0 in gids), "system_sgpr_workgroup_id_y":int(1 in gids), "system_sgpr_workgroup_id_z":int(2 in gids),
|
||||
"next_free_vgpr":round_up(max_vgpr, 8), "next_free_sgpr":round_up(max_sgpr, 8)}
|
||||
binary = create_elf(code_bytes, kd, arch)
|
||||
return prg.replace(src=prg.src[:3]+(UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=binary)))
|
||||
@@ -23,7 +23,7 @@ def elf_loader(blob:bytes, force_section_align:int=1, link_libs:list[str]|None=N
|
||||
def _to_carray(sh, ctype): return (ctype * (sh.header.sh_size // sh.header.sh_entsize)).from_buffer_copy(sh.content)
|
||||
rel = [(sh, sh.name[4:], _to_carray(sh, libc.Elf64_Rel)) for sh in sections if sh.header.sh_type == libc.SHT_REL]
|
||||
rela = [(sh, sh.name[5:], _to_carray(sh, libc.Elf64_Rela)) for sh in sections if sh.header.sh_type == libc.SHT_RELA]
|
||||
symtab = [_to_carray(sh, libc.Elf64_Sym) for sh in sections if sh.header.sh_type == libc.SHT_SYMTAB][0]
|
||||
symtab = next((_to_carray(sh, libc.Elf64_Sym) for sh in sections if sh.header.sh_type == libc.SHT_SYMTAB), None)
|
||||
progbits = [sh for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS]
|
||||
|
||||
# Prealloc image for all fixed addresses.
|
||||
@@ -39,7 +39,7 @@ def elf_loader(blob:bytes, force_section_align:int=1, link_libs:list[str]|None=N
|
||||
for sh, trgt_sh_name, c_rels in rel + rela:
|
||||
if trgt_sh_name == ".eh_frame": continue
|
||||
target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
|
||||
rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
|
||||
rels = [(r.r_offset, unwrap(symtab)[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
|
||||
relocs += [(target_image_off + roff, link_sym(_strtab(sh_strtab, sym.st_name), link_libs or []) if sym.st_shndx == 0 else
|
||||
sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]
|
||||
|
||||
|
||||
@@ -76,6 +76,9 @@ class Ops(FastEnum):
|
||||
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
|
||||
CUSTOM = auto(); CUSTOMI = auto()
|
||||
|
||||
# INS is a machine instruction
|
||||
INS = auto()
|
||||
|
||||
# ** 6 -- ops that don't exist in programs **
|
||||
|
||||
# tensor graph ops
|
||||
|
||||
@@ -209,7 +209,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
|
||||
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS:
|
||||
return None
|
||||
|
||||
case Ops.INDEX:
|
||||
|
||||
@@ -177,6 +177,9 @@ shared_codegen_spec = PatternMatcher([
|
||||
# CUSTOM (inline and non inline)
|
||||
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
|
||||
|
||||
# assembly instruction
|
||||
(UPat(Ops.INS), lambda: True),
|
||||
|
||||
# INDEX (2-arg and 3-arg with bool gate)
|
||||
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),
|
||||
|
||||
@@ -186,6 +186,7 @@ const waveColor = (op) => {
|
||||
: op.includes("LOAD") || op === "SMEM" ? "LOAD" : op.includes("STORE") ? "STORE" : op;
|
||||
ret = WAVE_COLORS[cat] ?? "#ffffff";
|
||||
if (op.includes("OTHER_") || op.includes("_ALT")) { ret = darkenHex(ret, 75) }
|
||||
if (op.includes("LDS_")) { ret = darkenHex(ret, 25) }
|
||||
return ret
|
||||
};
|
||||
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),
|
||||
|
||||
@@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.PARAM:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#808080", Ops.BINARY: "#404040",
|
||||
|
||||
Reference in New Issue
Block a user