mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
mi350x 1tflop bf16 gemm in extra (#13702)
This commit is contained in:
1
extra/gemm/.gitignore
vendored
1
extra/gemm/.gitignore
vendored
@@ -1,3 +1,2 @@
|
||||
*.s
|
||||
*.ll
|
||||
fp32_sgemm_amd
|
||||
|
||||
15338
extra/gemm/asm/gemm.s
Normal file
15338
extra/gemm/asm/gemm.s
Normal file
File diff suppressed because it is too large
Load Diff
62
extra/gemm/asm/test.py
Normal file
62
extra/gemm/asm/test.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Run assembly on the AMD runtime and check correctness
|
||||
# VIZ=2 to profile
|
||||
import pathlib
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.uop.ops import track_rewrites, UOp
|
||||
from tinygrad.helpers import TracingKey
|
||||
|
||||
# assembly source lives next to this script
fp = pathlib.Path(__file__).with_name("gemm.s")

# ** generate inputs on CPU

N, scale = 8192, 10.0

import torch
torch.manual_seed(0)

def _rand_bf16() -> torch.Tensor:
  # fp32 normal samples, scaled down, then rounded to bfloat16
  return torch.randn(N, N, dtype=torch.float32, device="cpu").div(scale).to(torch.bfloat16).contiguous()

# A is drawn before B so the seeded random stream is consumed in the same order
A = _rand_bf16()
B = _rand_bf16()
Bt = B.t().contiguous() # transpose B for the baseline gemm
C_torch = torch.matmul(A, Bt)
|
||||
|
||||
# ** copy buffers to AMD
|
||||
|
||||
# input creation and validation run on the copy engine for simpler tracing
|
||||
|
||||
def from_torch(t:torch.Tensor) -> Tensor:
  """Wrap the memory of torch tensor `t` as a bfloat16 tinygrad tensor and realize it on the default device."""
  # view the torch storage in place on the CPU device, always interpreted as bfloat16
  host_view = Tensor.from_blob(t.data_ptr(), t.shape, dtype=dtypes.bfloat16, device="cpu")
  return host_view.to(Device.DEFAULT).realize()
|
||||
|
||||
# baseline gemm: accumulate in float32, then cast the result back to bfloat16
C_tiny = from_torch(A).matmul(from_torch(Bt), dtype=dtypes.float32).cast(dtypes.bfloat16)
# output tensor for the assembly kernel, same shape/dtype as the baseline, allocated up front
C_asm = Tensor.empty_like(C_tiny)
C_asm.uop.buffer.allocate()
|
||||
|
||||
# ** run gemms
|
||||
|
||||
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name,), ret=ret))
def get_asm_gemm(ast:UOp, fp:pathlib.Path) -> ProgramSpec:
  """Compile the assembly file at `fp` with the default device's compiler and wrap it in a ProgramSpec named "gemm".

  The launch dimensions and the three global buffer indices are fixed values.
  """
  asm_text = fp.read_text()
  compiler = Device[Device.DEFAULT].compiler
  binary = compiler.compile(asm_text)
  return ProgramSpec("gemm", asm_text, Device.DEFAULT, ast, lib=binary,
                     global_size=[1024, 1, 1], local_size=[256, 1, 1], globals=[0, 1, 2])
|
||||
|
||||
# the baseline matmul must schedule to exactly one kernel
sched = C_tiny.schedule()
assert len(sched) == 1
baseline_ei = sched[-1].lower()
ast = baseline_ei.ast

# second exec item: the hand-written kernel, reusing the baseline's ast
prg = get_asm_gemm(ast, fp)
asm_bufs = [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer]
eis:list[ExecItem] = [baseline_ei, ExecItem(ast, asm_bufs, prg=CompiledRunner(prg))]

# run baseline first, then the assembly kernel, waiting on each
for ei in eis: ei.run(wait=True)
|
||||
|
||||
# ** correctness
|
||||
|
||||
import ctypes
|
||||
|
||||
def torch_bf16(t:Tensor) -> torch.Tensor:
  """Copy tinygrad tensor `t` to the CPU and return its contents as a torch.bfloat16 tensor shaped like `t`.

  The element count and shape come from `t` itself, so the helper works for any tensor
  rather than only ones that match the global C_asm.
  """
  # keep the host-side tensor referenced until the data has been copied out of it
  host = t.to("cpu").realize()
  raw = host.uop.buffer._buf
  # reinterpret the host allocation as 16-bit words so torch can read it as bfloat16
  words = (ctypes.c_uint16*t.uop.size).from_address(raw.va_addr)
  view = torch.frombuffer(words, dtype=torch.bfloat16, count=t.uop.size).reshape(t.shape)
  # NOTE(review): `view` only aliases memory owned by `host`; that memory is presumably released
  # once `host` is collected, so return an owning copy instead of the alias
  return view.clone()
|
||||
|
||||
# both the assembly result and the tinygrad baseline must match torch; asm is checked first
for out in (C_asm, C_tiny):
  assert torch.allclose(torch_bf16(out), C_torch, rtol=1e-2, atol=1e-3)
|
||||
179
extra/gemm/asm/unpack_kd.py
Normal file
179
extra/gemm/asm/unpack_kd.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# unpack the complete kernel descriptor of an amdgpu ELF for gfx950
|
||||
# https://rocm.docs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPUUsage.html#code-object-v3-kernel-descriptor
|
||||
import struct, pathlib
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
||||
def bits(x, lo, hi):
  """Return the unsigned bit field x[hi:lo], both bit indices inclusive.

  `lo` must not exceed `hi`: when lo == hi + 1 the mask is empty and the result is
  silently 0, and any larger gap raises ValueError from the negative shift.
  """
  return (x >> lo) & ((1 << (hi - lo + 1)) - 1)

def assert_zero(x, lo, hi):
  """Assert that the bit field x[hi:lo] is zero, reporting the bit range and value on failure."""
  field = bits(x, lo, hi)
  assert field == 0, f"bits {hi}:{lo} must be zero, got {field:#x}"
|
||||
|
||||
# the compiled code object is expected in a file named "lib" next to this script
fp = pathlib.Path(__file__).parent/"lib"
lib = fp.read_bytes()

image, sections, relocs = elf_loader(lib)
# address of the first section named .rodata (StopIteration if there is none)
rodata_entry = next(sh.header.sh_addr for sh in sections if sh.name == ".rodata")

# rodata is exactly 64 bytes
kd = image[rodata_entry:rodata_entry+64]
# treat the whole descriptor as one little-endian integer so fields can be addressed by absolute bit index
desc = int.from_bytes(kd, "little")
|
||||
|
||||
# the first four 32-bit words of the descriptor, in order
group_segment_fixed_size, private_segment_fixed_size, kernarg_size, reserved_127_96 = (
  bits(desc, lo, lo + 31) for lo in range(0, 128, 32))
assert reserved_127_96 == 0

for label, value in (("GROUP_SEGMENT_FIXED_SIZE:", group_segment_fixed_size),
                     ("PRIVATE_SEGMENT_FIXED_SIZE:", private_segment_fixed_size),
                     ("KERNARG_SIZE:", kernarg_size),
                     ("RESERVED 127:96:", reserved_127_96)):
  print(label, value)
|
||||
|
||||
# KERNEL_CODE_ENTRY_BYTE_OFFSET occupies bits 191:128 and is signed
entry_off = bits(desc, 128, 191)

# sign-extend manually if needed
if entry_off & (1 << 63):
  entry_off -= 1 << 64

print("KERNEL_CODE_ENTRY_BYTE_OFFSET:", entry_off)

# the offset is relative to the descriptor's own address; the descriptor was read from the
# start of .rodata, so use that address instead of a constant that fits only one binary
kd_addr = rodata_entry
entry_addr = kd_addr + entry_off

print("Computed entry address: 0x%016x" % entry_addr)
print("256B aligned:", entry_addr % 256 == 0)
|
||||
|
||||
# three consecutive 32-bit resource words; note rsrc3 is stored before rsrc1 and rsrc2
pgm_rsrc3, pgm_rsrc1, pgm_rsrc2 = (bits(desc, lo, lo + 31) for lo in (352, 384, 416))

for fmt, word in (("COMPUTE_PGM_RSRC3: 0x%08x", pgm_rsrc3),
                  ("COMPUTE_PGM_RSRC1: 0x%08x", pgm_rsrc1),
                  ("COMPUTE_PGM_RSRC2: 0x%08x", pgm_rsrc2)):
  print(fmt % word)
|
||||
|
||||
# rsrc 3
|
||||
|
||||
accum_offset_raw = bits(pgm_rsrc3, 0, 5)
# the raw field is decoded as (value + 1) * 4
accum_offset_vgprs = (accum_offset_raw + 1) * 4
assert_zero(pgm_rsrc3, 6, 15)
tg_split = bits(pgm_rsrc3, 16, 16)

print("RSRC3.ACCUM_OFFSET (AccVGPR index):", accum_offset_vgprs)
print("RSRC3.TG_SPLIT:", tg_split)
|
||||
|
||||
# rsrc 1
|
||||
|
||||
# register counts are stored as granule counts
vgpr_gran = bits(pgm_rsrc1, 0, 5)
sgpr_gran = bits(pgm_rsrc1, 6, 9)

# NOTE: this is vgprs + agprs
vgprs_used = (vgpr_gran + 1) * 8
assert 0 <= vgprs_used <= 512

# the sgpr field is halved before being scaled by the block size of 16
k = sgpr_gran // 2
sgprs_used = (k + 1) * 16

print("RSRC1.VGPRS:", vgprs_used)
print("RSRC1.SGPRS:", sgprs_used)

assert_zero(pgm_rsrc1, 10, 11)

# four 2-bit float mode fields, low to high; bits() takes (lo, hi) in that order
float_round_mode_32 = bits(pgm_rsrc1, 12, 13)
float_round_mode_16_64 = bits(pgm_rsrc1, 14, 15)
float_denorm_mode_32 = bits(pgm_rsrc1, 16, 17)
float_denorm_mode_16_64 = bits(pgm_rsrc1, 18, 19)

priv = bits(pgm_rsrc1, 20, 20)
assert priv == 0
enable_dx10_clamp_wg_rr_en = bits(pgm_rsrc1, 21, 21)
debug_mode = bits(pgm_rsrc1, 22, 22)
enable_ieee_mode = bits(pgm_rsrc1, 23, 23)
bulky = bits(pgm_rsrc1, 24, 24)
assert bulky == 0
cdbg_user = bits(pgm_rsrc1, 25, 25)
assert cdbg_user == 0
fp16_ovfl = bits(pgm_rsrc1, 26, 26)
assert_zero(pgm_rsrc1, 27, 28) # reserved
assert_zero(pgm_rsrc1, 29, 29) # WGP_MODE (reserved on gfx9)
assert_zero(pgm_rsrc1, 30, 30) # MEM_ORDERED (reserved on gfx9)
assert_zero(pgm_rsrc1, 31, 31) # FWD_PROGRESS (reserved on gfx9)
|
||||
|
||||
# rsrc 2
|
||||
|
||||
# decode every field first, in ascending bit order
enable_private_segment = bits(pgm_rsrc2, 0, 0) # SCRATCH_EN
user_sgpr_count = bits(pgm_rsrc2, 1, 5) # USER_SGPR
enable_trap_handler = bits(pgm_rsrc2, 6, 6) # TRAP_PRESENT (must be 0 here)
(enable_sgpr_workgroup_id_x, enable_sgpr_workgroup_id_y,
 enable_sgpr_workgroup_id_z, enable_sgpr_workgroup_info) = (bits(pgm_rsrc2, b, b) for b in range(7, 11))
enable_vgpr_workitem_id = bits(pgm_rsrc2, 11, 12) # TIDIG_CMP_CNT enum (0..3)
enable_exception_address_watch, enable_exception_memory = (bits(pgm_rsrc2, b, b) for b in (13, 14))
granulated_lds_size = bits(pgm_rsrc2, 15, 23)
(enable_exception_fp_invalid, enable_exception_fp_denorm_src, enable_exception_fp_div0,
 enable_exception_fp_overflow, enable_exception_fp_underflow, enable_exception_fp_inexact,
 enable_exception_int_div0) = (bits(pgm_rsrc2, b, b) for b in range(24, 31))

# fields that must be zero, checked in ascending bit order before anything is printed
assert enable_trap_handler == 0
assert enable_exception_address_watch == 0
assert enable_exception_memory == 0
assert granulated_lds_size == 0 # spec: must be 0; CP uses dispatch packet rounding
assert_zero(pgm_rsrc2, 31, 31)

for label, value in (
  ("RSRC2.ENABLE_PRIVATE_SEGMENT:", enable_private_segment),
  ("RSRC2.USER_SGPR_COUNT:", user_sgpr_count),
  ("RSRC2.ENABLE_SGPR_WORKGROUP_ID_X:", enable_sgpr_workgroup_id_x),
  ("RSRC2.ENABLE_SGPR_WORKGROUP_ID_Y:", enable_sgpr_workgroup_id_y),
  ("RSRC2.ENABLE_SGPR_WORKGROUP_ID_Z:", enable_sgpr_workgroup_id_z),
  ("RSRC2.ENABLE_SGPR_WORKGROUP_INFO:", enable_sgpr_workgroup_info),
  ("RSRC2.ENABLE_VGPR_WORKITEM_ID (enum):", enable_vgpr_workitem_id),
  ("RSRC2.EXC_FP_INVALID:", enable_exception_fp_invalid),
  ("RSRC2.EXC_FP_DENORM_SRC:", enable_exception_fp_denorm_src),
  ("RSRC2.EXC_FP_DIV0:", enable_exception_fp_div0),
  ("RSRC2.EXC_FP_OVERFLOW:", enable_exception_fp_overflow),
  ("RSRC2.EXC_FP_UNDERFLOW:", enable_exception_fp_underflow),
  ("RSRC2.EXC_FP_INEXACT:", enable_exception_fp_inexact),
  ("RSRC2.EXC_INT_DIV0:", enable_exception_int_div0),
):
  print(label, value)
|
||||
|
||||
# user sgprs
|
||||
|
||||
# one enable bit per user SGPR kind, bits 448..454
enable_sgpr_private_segment_buffer = bits(desc, 448, 448)
enable_sgpr_dispatch_ptr = bits(desc, 449, 449)
enable_sgpr_queue_ptr = bits(desc, 450, 450)
enable_sgpr_kernarg_segment_ptr = bits(desc, 451, 451)
enable_sgpr_dispatch_id = bits(desc, 452, 452)
enable_sgpr_flat_scratch_init = bits(desc, 453, 453)
enable_sgpr_private_segment_size = bits(desc, 454, 454)

assert_zero(desc, 455, 457)

print("DESC.ENABLE_SGPR_PRIVATE_SEGMENT_BUFFER:", enable_sgpr_private_segment_buffer)
print("DESC.ENABLE_SGPR_DISPATCH_PTR:", enable_sgpr_dispatch_ptr)
print("DESC.ENABLE_SGPR_QUEUE_PTR:", enable_sgpr_queue_ptr)
print("DESC.ENABLE_SGPR_KERNARG_SEGMENT_PTR:", enable_sgpr_kernarg_segment_ptr)
print("DESC.ENABLE_SGPR_DISPATCH_ID:", enable_sgpr_dispatch_id)
print("DESC.ENABLE_SGPR_FLAT_SCRATCH_INIT:", enable_sgpr_flat_scratch_init)
print("DESC.ENABLE_SGPR_PRIVATE_SEGMENT_SIZE:", enable_sgpr_private_segment_size)

# bit 458 is ENABLE_WAVEFRONT_SIZE32 in the descriptor table, reserved (zero) on gfx9-family targets
assert_zero(desc, 458, 458)

# USES_DYNAMIC_STACK is the single bit 459; it may legitimately be set, so it is read, not asserted
uses_dynamic_stack = bits(desc, 459, 459)
print("DESC.USES_DYNAMIC_STACK:", uses_dynamic_stack)

assert_zero(desc, 460, 463)
kernarg_preload_spec_length = bits(desc, 464, 470)
print("DESC.KERNARG_PRELOAD_SPEC_LENGTH:", kernarg_preload_spec_length)

kernarg_preload_spec_offset = bits(desc, 471, 479)
print("DESC.KERNARG_PRELOAD_SPEC_OFFSET:", kernarg_preload_spec_offset)

# the last 32 bits of the 64-byte descriptor are reserved
assert_zero(desc, 480, 511)
|
||||
Reference in New Issue
Block a user