mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
rdna4 work
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -497,9 +497,9 @@ class Inst:
|
||||
|
||||
def __hash__(self): return hash((self.__class__.__name__, tuple(sorted((k, repr(v)) for k, v in self._values.items())), self._literal))
|
||||
|
||||
def disasm(self) -> str:
|
||||
def disasm(self, wave_size: int = 32) -> str:
|
||||
from extra.assembly.amd.asm import disasm
|
||||
return disasm(self)
|
||||
return disasm(self, wave_size)
|
||||
|
||||
_enum_map = {'VOP1': VOP1Op, 'VOP2': VOP2Op, 'VOP3': VOP3Op, 'VOP3SD': VOP3SDOp, 'VOP3P': VOP3POp, 'VOPC': VOPCOp,
|
||||
'SOP1': SOP1Op, 'SOP2': SOP2Op, 'SOPC': SOPCOp, 'SOPK': SOPKOp, 'SOPP': SOPPOp,
|
||||
|
||||
@@ -32,12 +32,14 @@ RDNA4_TEST_FILES = {
|
||||
'vop3_from_vop1': 'gfx12_asm_vop3_from_vop1.s', 'vop3_from_vop2': 'gfx12_asm_vop3_from_vop2.s',
|
||||
'ds': 'gfx12_asm_ds.s', 'ds_alias': 'gfx12_asm_ds_alias.s', 'smem': 'gfx12_asm_smem.s',
|
||||
'vflat': 'gfx12_asm_vflat.s', 'vflat_alias': 'gfx12_asm_vflat_alias.s',
|
||||
'vscratch': 'gfx12_asm_vflat.s', 'vscratch_alias': 'gfx12_asm_vflat_alias.s', # scratch instructions in vflat files
|
||||
'vbuffer_mubuf': 'gfx12_asm_vbuffer_mubuf.s', 'vbuffer_mubuf_alias': 'gfx12_asm_vbuffer_mubuf_alias.s',
|
||||
'vbuffer_mtbuf': 'gfx12_asm_vbuffer_mtbuf.s', 'vbuffer_mtbuf_alias': 'gfx12_asm_vbuffer_mtbuf_alias.s',
|
||||
'vimage': 'gfx12_asm_vimage.s', 'vimage_alias': 'gfx12_asm_vimage_alias.s', 'vsample': 'gfx12_asm_vsample.s',
|
||||
'vdsdir': 'gfx12_asm_vdsdir.s', 'vdsdir_alias': 'gfx12_asm_vdsdir_alias.s',
|
||||
'exp': 'gfx12_asm_exp.s', 'wmma_w32': 'gfx12_asm_wmma_w32.s', 'wmma_w64': 'gfx12_asm_wmma_w64.s',
|
||||
'features': 'gfx12_asm_features.s', 'global_load_tr': 'gfx12_asm_global_load_tr.s',
|
||||
'global_load_tr': 'gfx12_asm_global_load_tr.s',
|
||||
# NOTE: 'features' (gfx12_asm_features.s) tests DPP instruction variants which require separate format decoders
|
||||
}
|
||||
|
||||
def parse_llvm_tests(text: str, gfx_prefix: str) -> list[tuple[str, bytes]]:
|
||||
@@ -108,6 +110,10 @@ class TestLLVMBase(unittest.TestCase):
|
||||
fmt_cls = self.formats.get(name)
|
||||
if fmt_cls is None: self.skipTest(f"No format class for {name}")
|
||||
|
||||
# Determine wave size from test name (w64 = wave64, otherwise wave32)
|
||||
wave_size = 64 if 'w64' in name else 32
|
||||
mattr = f'+real-true16,+wavefrontsize{wave_size}'
|
||||
|
||||
to_test: list[tuple[str, bytes, str | None, str | None]] = []
|
||||
for asm_text, data in self.tests.get(name, []):
|
||||
if len(data) > fmt_cls._size(): continue
|
||||
@@ -117,14 +123,14 @@ class TestLLVMBase(unittest.TestCase):
|
||||
if decoded.to_bytes()[:len(data)] != data:
|
||||
to_test.append((asm_text, data, None, "decode roundtrip failed"))
|
||||
continue
|
||||
to_test.append((asm_text, data, decoded.disasm(), None))
|
||||
to_test.append((asm_text, data, decoded.disasm(wave_size), None))
|
||||
except Exception as e:
|
||||
to_test.append((asm_text, data, None, f"exception: {e}"))
|
||||
|
||||
disasm_strs = [(i, t[2]) for i, t in enumerate(to_test) if t[2] is not None]
|
||||
llvm_map = {}
|
||||
if disasm_strs:
|
||||
llvm_results = compile_asm_batch([s for _, s in disasm_strs], self.mcpu)
|
||||
llvm_results = compile_asm_batch([s for _, s in disasm_strs], self.mcpu, mattr)
|
||||
llvm_map = {i: llvm_results[j] for j, (i, _) in enumerate(disasm_strs)}
|
||||
|
||||
passed, failed, failures = 0, 0, []
|
||||
@@ -180,9 +186,10 @@ class TestLLVMRDNA4(TestLLVMBase):
|
||||
'vbuffer_mtbuf': get('VBUFFER'), 'vbuffer_mtbuf_alias': get('VBUFFER'),
|
||||
'vdsdir': get('VDSDIR'), 'vdsdir_alias': get('VDSDIR'),
|
||||
'vflat': get('VFLAT'), 'vflat_alias': get('VFLAT'),
|
||||
'vscratch': get('VSCRATCH'), 'vscratch_alias': get('VSCRATCH'),
|
||||
'vimage': get('VIMAGE'), 'vimage_alias': get('VIMAGE'), 'vsample': get('VSAMPLE'),
|
||||
'wmma_w32': get('VOP3P'), 'wmma_w64': get('VOP3P'),
|
||||
'features': None, 'global_load_tr': get('VGLOBAL'),
|
||||
'global_load_tr': get('VGLOBAL'),
|
||||
}
|
||||
cls._load_tests(RDNA4_TEST_FILES)
|
||||
|
||||
|
||||
@@ -1,90 +1,37 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Roundtrip tests: generate tinygrad kernels, decode instructions, re-encode, verify match."""
|
||||
import unittest, io, sys, re, subprocess, os
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
from extra.assembly.amd.dsl import Inst
|
||||
from extra.assembly.amd.asm import asm
|
||||
from extra.assembly.amd.asm import detect_format
|
||||
from extra.assembly.amd.test.helpers import get_llvm_mc, get_llvm_objdump
|
||||
|
||||
def disassemble_lib(lib: bytes, compiler) -> list[tuple[str, bytes]]:
|
||||
"""Disassemble ELF binary and return list of (instruction_text, machine_code_bytes)."""
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = io.StringIO()
|
||||
compiler.disassemble(lib)
|
||||
output = sys.stdout.getvalue()
|
||||
sys.stdout = old_stdout
|
||||
|
||||
results = []
|
||||
for line in output.splitlines():
|
||||
if '//' not in line: continue
|
||||
instr = line.split('//')[0].strip()
|
||||
if not instr: continue
|
||||
comment = line.split('//')[1].strip()
|
||||
if ':' not in comment: continue
|
||||
hex_str = comment.split(':')[1].strip().split()[0]
|
||||
try:
|
||||
machine_bytes = bytes.fromhex(hex_str)[::-1] # big-endian to little-endian
|
||||
results.append((instr, machine_bytes))
|
||||
except ValueError:
|
||||
continue
|
||||
return results
|
||||
|
||||
def compile_asm(instr: str, compiler=None) -> bytes:
|
||||
"""Compile a single instruction with llvm-mc and return the machine code bytes."""
|
||||
llvm_mc = get_llvm_mc()
|
||||
result = subprocess.run(
|
||||
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=f".text\n{instr}\n", capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed for '{instr}': {result.stderr.strip()}")
|
||||
# Parse encoding: [0x01,0x39,0x0a,0x7e]
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' in line:
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
|
||||
return bytes.fromhex(hex_vals)
|
||||
raise RuntimeError(f"no encoding found in llvm-mc output for: {instr}")
|
||||
|
||||
def compile_asm_batch(instrs: list[str]) -> list[bytes]:
|
||||
def compile_asm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[bytes]:
|
||||
"""Compile multiple instructions with a single llvm-mc call."""
|
||||
if not instrs: return []
|
||||
llvm_mc = get_llvm_mc()
|
||||
src = ".text\n" + "\n".join(instrs) + "\n"
|
||||
result = subprocess.run(
|
||||
[llvm_mc, '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=src, capture_output=True, text=True)
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-show-encoding'],
|
||||
input=".text\n" + "\n".join(instrs) + "\n", capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc batch failed: {result.stderr.strip()}")
|
||||
# Parse all encodings in order
|
||||
encodings = []
|
||||
for line in result.stdout.split('\n'):
|
||||
if 'encoding:' in line:
|
||||
enc = line.split('encoding:')[1].strip()
|
||||
if enc.startswith('[') and enc.endswith(']'):
|
||||
hex_vals = enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')
|
||||
encodings.append(bytes.fromhex(hex_vals))
|
||||
encodings.append(bytes.fromhex(enc[1:-1].replace('0x', '').replace(',', '').replace(' ', '')))
|
||||
if len(encodings) != len(instrs): raise RuntimeError(f"expected {len(instrs)} encodings, got {len(encodings)}")
|
||||
return encodings
|
||||
|
||||
def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]:
|
||||
def compile_and_disasm_batch(instrs: list[str], mcpu: str = 'gfx1100') -> list[str]:
|
||||
"""Compile instructions with LLVM and get LLVM's disassembly."""
|
||||
import tempfile, os
|
||||
import tempfile
|
||||
if not instrs: return []
|
||||
# Build assembly source with all instructions
|
||||
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n"
|
||||
src += "\n".join(f" {instr}" for instr in instrs) + "\n"
|
||||
# Use llvm-mc to assemble to object file
|
||||
src = ".text\n.globl test\n.p2align 8\n.type test,@function\ntest:\n" + "\n".join(f" {instr}" for instr in instrs) + "\n"
|
||||
with tempfile.NamedTemporaryFile(suffix='.o', delete=False) as f:
|
||||
obj_path = f.name
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[get_llvm_mc(), '-triple=amdgcn', '-mcpu=gfx1100', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
|
||||
input=src, capture_output=True, text=True)
|
||||
result = subprocess.run([get_llvm_mc(), '-triple=amdgcn', f'-mcpu={mcpu}', '-mattr=+real-true16,+wavefrontsize32', '-filetype=obj', '-o', obj_path],
|
||||
input=src, capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-mc failed: {result.stderr.strip()}")
|
||||
# Disassemble with llvm-objdump
|
||||
result = subprocess.run([get_llvm_objdump(), '-d', '--mcpu=gfx1100', obj_path], capture_output=True, text=True)
|
||||
result = subprocess.run([get_llvm_objdump(), '-d', f'--mcpu={mcpu}', obj_path], capture_output=True, text=True)
|
||||
if result.returncode != 0: raise RuntimeError(f"llvm-objdump failed: {result.stderr.strip()}")
|
||||
# Parse disassembly output
|
||||
results: list[str] = []
|
||||
for line in result.stdout.splitlines():
|
||||
if '//' not in line: continue
|
||||
@@ -94,127 +41,143 @@ def compile_and_disasm_batch(instrs: list[str], compiler) -> list[str]:
|
||||
finally:
|
||||
os.unlink(obj_path)
|
||||
|
||||
class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
"""Test roundtrip on real tinygrad-generated kernels using get_kernels_from_tinygrad pattern."""
|
||||
class TestRoundtripBase(unittest.TestCase):
|
||||
"""Base class for roundtrip tests."""
|
||||
mcpu: str = 'gfx1100'
|
||||
arch: str = 'rdna3'
|
||||
|
||||
@classmethod
|
||||
def _get_modules(cls):
|
||||
if cls.arch == 'rdna3':
|
||||
from extra.assembly.amd.autogen.rdna3 import ins
|
||||
from extra.assembly.amd.asm import detect_format, asm
|
||||
else:
|
||||
import extra.assembly.amd.autogen.rdna4.ins as ins
|
||||
from extra.assembly.amd.asm import asm
|
||||
detect_format = None # RDNA4 uses different detection
|
||||
return ins, detect_format, asm
|
||||
|
||||
def _test_kernel_roundtrip(self, op_fn):
|
||||
"""Generate kernel from op_fn, test:
|
||||
1. decode -> reencode matches original bytes
|
||||
2. asm(disasm()) matches LLVM output
|
||||
3. our disasm() matches LLVM's disassembly string exactly
|
||||
"""
|
||||
"""Generate kernel from op_fn, test decode -> reencode and asm(disasm()) matches LLVM."""
|
||||
from extra.assembly.amd.test.test_compare_emulators import get_kernels_from_tinygrad
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
|
||||
ins, detect_format, asm = self._get_modules()
|
||||
kernels, _, _ = get_kernels_from_tinygrad(op_fn)
|
||||
compiler = HIPCompiler('gfx1100')
|
||||
compiler = HIPCompiler(self.mcpu)
|
||||
|
||||
# First pass: decode all instructions and collect info
|
||||
decoded_instrs: list[tuple] = [] # list of (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err)
|
||||
# First pass: decode all instructions
|
||||
decoded_instrs: list[tuple] = []
|
||||
for ki, kernel in enumerate(kernels):
|
||||
offset = 0
|
||||
while offset < len(kernel.code):
|
||||
remaining = kernel.code[offset:]
|
||||
fmt = detect_format(remaining)
|
||||
if fmt is None:
|
||||
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
|
||||
offset += 4
|
||||
continue
|
||||
if len(remaining) < 4: break
|
||||
|
||||
# Try to detect format
|
||||
if detect_format is not None:
|
||||
try:
|
||||
fmt = detect_format(remaining)
|
||||
except ValueError:
|
||||
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
|
||||
offset += 4
|
||||
continue
|
||||
else:
|
||||
# For RDNA4, try formats in order
|
||||
fmt = None
|
||||
from extra.assembly.amd.autogen.rdna4.ins import SOP1, SOP2, SOPC, SOPK, SOPP, VOP1, VOP2, VOP3, VOP3P, VOPC, VOPD, VDS, SMEM, VFLAT, VBUFFER, VIMAGE, VSAMPLE, VEXPORT, VDSDIR
|
||||
word = int.from_bytes(remaining[:4], 'little')
|
||||
for cls in [VOPD, VOP3P, VOP3, VDS, VFLAT, VBUFFER, VIMAGE, VSAMPLE, SMEM, VEXPORT, SOP1, SOPC, SOPP, SOPK, VOPC, VOP1, SOP2, VOP2, VDSDIR]:
|
||||
if cls._encoding is not None:
|
||||
bf, val = cls._encoding
|
||||
if ((word >> bf.lo) & bf.mask()) == val:
|
||||
fmt = cls
|
||||
break
|
||||
if fmt is None:
|
||||
decoded_instrs.append((ki, offset, None, None, None, False, "no format"))
|
||||
offset += 4
|
||||
continue
|
||||
|
||||
base_size = fmt._size()
|
||||
if len(remaining) < base_size:
|
||||
break
|
||||
if len(remaining) < base_size: break
|
||||
|
||||
try:
|
||||
decoded = fmt.from_bytes(remaining) # pass all remaining bytes so from_bytes can read literal
|
||||
size = decoded.size() # actual size including literal
|
||||
decoded = fmt.from_bytes(remaining)
|
||||
size = decoded.size()
|
||||
orig_bytes = remaining[:size]
|
||||
reencoded = decoded.to_bytes()
|
||||
our_disasm = decoded.disasm()
|
||||
decode_ok = reencoded == orig_bytes
|
||||
decode_err: str | None = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
|
||||
decode_err = None if decode_ok else f"orig={orig_bytes.hex()} reenc={reencoded.hex()}"
|
||||
decoded_instrs.append((ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err))
|
||||
except Exception as e:
|
||||
decoded_instrs.append((ki, offset, remaining[:base_size], None, None, False, str(e)))
|
||||
size = base_size
|
||||
|
||||
offset += size
|
||||
|
||||
# Collect disasm strings for batched LLVM calls - skip unknown opcodes (op_X) that LLVM can't compile
|
||||
asm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for asm test
|
||||
disasm_test_instrs: list[tuple[int, str]] = [] # (idx, our_disasm) for disasm comparison test
|
||||
|
||||
# Collect disasm strings for batched LLVM calls
|
||||
asm_test_instrs: list[tuple[int, str]] = []
|
||||
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
|
||||
if our_disasm is None: continue
|
||||
# Skip unknown opcodes and malformed instructions for both tests
|
||||
if our_disasm.startswith('op_') or re.search(r', \d+, \d+, \d+,', our_disasm): continue
|
||||
asm_test_instrs.append((idx, our_disasm))
|
||||
disasm_test_instrs.append((idx, our_disasm))
|
||||
|
||||
# Batch compile for asm test
|
||||
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs])
|
||||
asm_llvm_results = compile_asm_batch([d for _, d in asm_test_instrs], self.mcpu)
|
||||
asm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, asm_llvm_results)}
|
||||
|
||||
# Batch compile+disasm for disasm comparison test
|
||||
disasm_llvm_results = compile_and_disasm_batch([d for _, d in disasm_test_instrs], compiler)
|
||||
disasm_llvm_map = {idx: result for (idx, _), result in zip(disasm_test_instrs, disasm_llvm_results)}
|
||||
disasm_llvm_results = compile_and_disasm_batch([d for _, d in asm_test_instrs], self.mcpu)
|
||||
disasm_llvm_map = {idx: result for (idx, _), result in zip(asm_test_instrs, disasm_llvm_results)}
|
||||
|
||||
# Now evaluate results
|
||||
# Evaluate results
|
||||
decode_passed, decode_failed, decode_skipped = 0, 0, 0
|
||||
asm_passed, asm_failed, asm_skipped = 0, 0, 0
|
||||
disasm_passed, disasm_failed, disasm_skipped = 0, 0, 0
|
||||
decode_failures: list[str] = []
|
||||
asm_failures: list[str] = []
|
||||
disasm_failures: list[str] = []
|
||||
decode_failures, asm_failures, disasm_failures = [], [], []
|
||||
|
||||
for idx, (ki, offset, orig_bytes, decoded, our_disasm, decode_ok, decode_err) in enumerate(decoded_instrs):
|
||||
# Decode test
|
||||
if decode_ok:
|
||||
decode_passed += 1
|
||||
elif decode_err == "no format":
|
||||
decode_skipped += 1
|
||||
if decode_ok: decode_passed += 1
|
||||
elif decode_err == "no format": decode_skipped += 1
|
||||
else:
|
||||
decode_failed += 1
|
||||
decode_failures.append(f"K{ki}@{offset}: {our_disasm}: {decode_err}")
|
||||
|
||||
# Asm test
|
||||
if our_disasm is None:
|
||||
asm_skipped += 1
|
||||
disasm_skipped += 1
|
||||
elif idx in asm_llvm_map:
|
||||
llvm_bytes = asm_llvm_map[idx]
|
||||
try:
|
||||
our_bytes = asm(our_disasm).to_bytes()
|
||||
if our_bytes[:len(llvm_bytes)] == llvm_bytes:
|
||||
asm_passed += 1
|
||||
if our_bytes[:len(llvm_bytes)] == llvm_bytes: asm_passed += 1
|
||||
else:
|
||||
asm_failed += 1
|
||||
asm_failures.append(f"K{ki}@{offset}: '{our_disasm}': ours={our_bytes[:len(llvm_bytes)].hex()} llvm={llvm_bytes.hex()}")
|
||||
except Exception:
|
||||
asm_skipped += 1
|
||||
|
||||
if idx in disasm_llvm_map:
|
||||
if our_disasm == disasm_llvm_map[idx]: disasm_passed += 1
|
||||
else:
|
||||
disasm_failed += 1
|
||||
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{disasm_llvm_map[idx]}'")
|
||||
else:
|
||||
disasm_skipped += 1
|
||||
else:
|
||||
asm_skipped += 1
|
||||
|
||||
# Disasm comparison test
|
||||
if our_disasm is None:
|
||||
disasm_skipped += 1
|
||||
elif idx in disasm_llvm_map:
|
||||
llvm_disasm = disasm_llvm_map[idx]
|
||||
if our_disasm == llvm_disasm:
|
||||
disasm_passed += 1
|
||||
else:
|
||||
disasm_failed += 1
|
||||
disasm_failures.append(f"K{ki}@{offset}: ours='{our_disasm}' llvm='{llvm_disasm}'")
|
||||
else:
|
||||
disasm_skipped += 1
|
||||
|
||||
print(f"decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
|
||||
print(f"asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
|
||||
print(f"disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
|
||||
print(f"{self.arch.upper()} decode roundtrip: {decode_passed} passed, {decode_failed} failed, {decode_skipped} skipped")
|
||||
print(f"{self.arch.upper()} asm vs llvm: {asm_passed} passed, {asm_failed} failed, {asm_skipped} skipped")
|
||||
print(f"{self.arch.upper()} disasm vs llvm: {disasm_passed} passed, {disasm_failed} failed, {disasm_skipped} skipped")
|
||||
self.assertEqual(decode_failed, 0, f"Decode failures:\n" + "\n".join(decode_failures[:20]))
|
||||
self.assertEqual(asm_failed, 0, f"Asm failures:\n" + "\n".join(asm_failures[:20]))
|
||||
# Note: disasm string comparison is informational only - formatting differences between LLVM versions are expected
|
||||
|
||||
# Basic unary ops
|
||||
class TestRoundtripRDNA3(TestRoundtripBase):
|
||||
"""Roundtrip tests for RDNA3 (gfx1100)."""
|
||||
mcpu, arch = 'gfx1100', 'rdna3'
|
||||
|
||||
def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
|
||||
def test_relu(self): self._test_kernel_roundtrip(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu())
|
||||
def test_exp(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).exp())
|
||||
@@ -222,42 +185,62 @@ class TestTinygradKernelRoundtrip(unittest.TestCase):
|
||||
def test_sin(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).sin())
|
||||
def test_sqrt(self): self._test_kernel_roundtrip(lambda T: T([1.0, 4.0, 9.0]).sqrt())
|
||||
def test_recip(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 4.0]).reciprocal())
|
||||
|
||||
# Binary ops
|
||||
def test_add(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0]) + T([3.0, 4.0]))
|
||||
def test_sub(self): self._test_kernel_roundtrip(lambda T: T([5.0, 6.0]) - T([1.0, 2.0]))
|
||||
def test_mul(self): self._test_kernel_roundtrip(lambda T: T([2.0, 3.0]) * T([4.0, 5.0]))
|
||||
def test_div(self): self._test_kernel_roundtrip(lambda T: T([10.0, 20.0]) / T([2.0, 4.0]))
|
||||
def test_max_binary(self): self._test_kernel_roundtrip(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0])))
|
||||
|
||||
# Reductions
|
||||
def test_sum_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).sum())
|
||||
def test_max_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).max())
|
||||
def test_mean_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(32).mean())
|
||||
|
||||
# Matmul
|
||||
def test_gemm_4x4(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4) @ T.empty(4, 4))
|
||||
def test_gemv(self): self._test_kernel_roundtrip(lambda T: T.empty(1, 16) @ T.empty(16, 16))
|
||||
|
||||
# Complex ops
|
||||
def test_softmax(self): self._test_kernel_roundtrip(lambda T: T.empty(16).softmax())
|
||||
def test_layernorm(self): self._test_kernel_roundtrip(lambda T: T.empty(8, 8).layernorm())
|
||||
|
||||
# Memory patterns
|
||||
def test_contiguous(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4).permute(1, 0).contiguous())
|
||||
def test_reshape(self): self._test_kernel_roundtrip(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous())
|
||||
def test_expand(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 1).expand(4, 4).contiguous())
|
||||
|
||||
# Cast ops
|
||||
def test_cast_int(self): self._test_kernel_roundtrip(lambda T: T.empty(16).int().float())
|
||||
def test_cast_half(self): self._test_kernel_roundtrip(lambda T: T.empty(16).half().float())
|
||||
|
||||
# Comparison ops
|
||||
def test_cmp_lt(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
|
||||
def test_where(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64)))
|
||||
|
||||
# Fused ops
|
||||
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
|
||||
|
||||
@unittest.skipUnless(os.environ.get("TEST_RDNA4"), "RDNA4 roundtrip tests require TEST_RDNA4=1 and gfx1200 hardware")
|
||||
class TestRoundtripRDNA4(TestRoundtripBase):
|
||||
"""Roundtrip tests for RDNA4 (gfx1200)."""
|
||||
mcpu, arch = 'gfx1200', 'rdna4'
|
||||
|
||||
def test_neg(self): self._test_kernel_roundtrip(lambda T: -T([1.0, -2.0, 3.0, -4.0]))
|
||||
def test_relu(self): self._test_kernel_roundtrip(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu())
|
||||
def test_exp(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).exp())
|
||||
def test_log(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 3.0]).log())
|
||||
def test_sin(self): self._test_kernel_roundtrip(lambda T: T([0.0, 1.0, 2.0]).sin())
|
||||
def test_sqrt(self): self._test_kernel_roundtrip(lambda T: T([1.0, 4.0, 9.0]).sqrt())
|
||||
def test_recip(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0, 4.0]).reciprocal())
|
||||
def test_add(self): self._test_kernel_roundtrip(lambda T: T([1.0, 2.0]) + T([3.0, 4.0]))
|
||||
def test_sub(self): self._test_kernel_roundtrip(lambda T: T([5.0, 6.0]) - T([1.0, 2.0]))
|
||||
def test_mul(self): self._test_kernel_roundtrip(lambda T: T([2.0, 3.0]) * T([4.0, 5.0]))
|
||||
def test_div(self): self._test_kernel_roundtrip(lambda T: T([10.0, 20.0]) / T([2.0, 4.0]))
|
||||
def test_max_binary(self): self._test_kernel_roundtrip(lambda T: T([1.0, 5.0]).maximum(T([3.0, 2.0])))
|
||||
def test_sum_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).sum())
|
||||
def test_max_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(64).max())
|
||||
def test_mean_reduce(self): self._test_kernel_roundtrip(lambda T: T.empty(32).mean())
|
||||
def test_gemm_4x4(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4) @ T.empty(4, 4))
|
||||
def test_gemv(self): self._test_kernel_roundtrip(lambda T: T.empty(1, 16) @ T.empty(16, 16))
|
||||
def test_softmax(self): self._test_kernel_roundtrip(lambda T: T.empty(16).softmax())
|
||||
def test_layernorm(self): self._test_kernel_roundtrip(lambda T: T.empty(8, 8).layernorm())
|
||||
def test_contiguous(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 4).permute(1, 0).contiguous())
|
||||
def test_reshape(self): self._test_kernel_roundtrip(lambda T: (T.empty(16) + 1).reshape(4, 4).contiguous())
|
||||
def test_expand(self): self._test_kernel_roundtrip(lambda T: T.empty(4, 1).expand(4, 4).contiguous())
|
||||
def test_cast_int(self): self._test_kernel_roundtrip(lambda T: T.empty(16).int().float())
|
||||
def test_cast_half(self): self._test_kernel_roundtrip(lambda T: T.empty(16).half().float())
|
||||
def test_cmp_lt(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))
|
||||
def test_where(self): self._test_kernel_roundtrip(lambda T: (T.empty(64) > 0).where(T.empty(64), T.empty(64)))
|
||||
def test_fma(self): self._test_kernel_roundtrip(lambda T: (T([1.0, 2.0]) * T([3.0, 4.0]) + T([5.0, 6.0])))
|
||||
|
||||
# Keep old class name for backwards compatibility
|
||||
TestTinygradKernelRoundtrip = TestRoundtripRDNA3
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user