mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
* assembly/amd: add pcode ds ops * refactors * fix ds op * update autogen * fix flat bug * more tests * fix emu test * that's a hack * generic * fix all tests * two tests * fix test failure * better * remove __all__
403 lines
19 KiB
Python
403 lines
19 KiB
Python
# Test to compare Python and Rust RDNA3 emulators by running real tinygrad kernels
|
|
import unittest, ctypes, os
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
# These variables have to be in place before anything from tinygrad is imported.
# With MOCKGPU enabled, AMD GPU kernels can be generated without real hardware.
os.environ.update({"AMD": "1", "MOCKGPU": "1", "PYTHON_REMU": "1"})
|
|
|
|
from extra.assembly.amd.emu import WaveState, decode_program, step_wave, WAVE_SIZE, set_valid_mem_ranges, LDSMem
|
|
from extra.assembly.amd.test.helpers import KernelInfo
|
|
|
|
# Shared library loaded by RustEmulator via ctypes. The path is built from the fourth
# ancestor directory of this file (`parents[3]`) plus `remu/target/release/libremu.so`.
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
|
|
|
|
def _is_f32_nan(bits: int) -> bool:
|
|
"""Check if 32-bit value is a NaN (exponent all 1s, mantissa non-zero)."""
|
|
return (bits & 0x7f800000) == 0x7f800000 and (bits & 0x007fffff) != 0
|
|
|
|
def _vals_equal(a: int, b: int) -> bool:
|
|
"""Compare two 32-bit values, treating all NaN bit patterns as equal."""
|
|
if a == b: return True
|
|
return _is_f32_nan(a) and _is_f32_nan(b)
|
|
|
|
@dataclass
class StateSnapshot:
  """Plain-Python copy of one wave's registers, taken so two emulators can be compared."""
  pc: int
  scc: int
  vcc: int
  exec_mask: int
  sgpr: list[int]  # one value per scalar register
  vgpr: list[list[int]]  # indexed as vgpr[lane][register]

  def diff(self, other: 'StateSnapshot', n_lanes: int, arrow: str = " vs ") -> list[str]:
    """Describe every field in which `self` and `other` disagree.

    Only the first `n_lanes` lanes of `vgpr` are inspected. Register values that are both float32 NaNs count as
    equal. Each entry reads `<name>: <self value><arrow><other value>`; an empty list means no difference.
    """
    scalars_dec = (("pc", self.pc, other.pc), ("scc", self.scc, other.scc))
    scalars_hex = (("vcc", self.vcc, other.vcc), ("exec", self.exec_mask, other.exec_mask))
    report = [f"{name}: {mine}{arrow}{theirs}" for name, mine, theirs in scalars_dec if mine != theirs]
    report += [f"{name}: 0x{mine:08x}{arrow}0x{theirs:08x}" for name, mine, theirs in scalars_hex if mine != theirs]

    # VCC_LO/HI (106/107) and EXEC_LO/HI (126/127) alias vcc/exec_mask, which are already compared above
    aliased = (106, 107, 126, 127)
    report += [f"sgpr[{reg}]: 0x{mine:08x}{arrow}0x{theirs:08x}"
               for reg, (mine, theirs) in enumerate(zip(self.sgpr, other.sgpr))
               if reg not in aliased and not _vals_equal(mine, theirs)]

    for lane in range(n_lanes):
      report += [f"vgpr[{lane}][{reg}]: 0x{mine:08x}{arrow}0x{theirs:08x}"
                 for reg, (mine, theirs) in enumerate(zip(self.vgpr[lane], other.vgpr[lane]))
                 if not _vals_equal(mine, theirs)]
    return report
|
|
|
|
class CStateSnapshot(ctypes.Structure):
  """ctypes layout of the record that `wave_get_snapshot` fills in (see RustEmulator.get_snapshot)."""
  _fields_ = [
    ("pc", ctypes.c_uint32),
    ("scc", ctypes.c_uint32),
    ("vcc", ctypes.c_uint32),
    ("exec_mask", ctypes.c_uint32),
    ("sgpr", ctypes.c_uint32 * 128),
    ("vgpr", (ctypes.c_uint32 * 256) * 32),  # 32 rows of 256 values, copied out as vgpr[row][register]
  ]

  def to_snapshot(self) -> 'StateSnapshot':
    """Copy the C struct into a StateSnapshot made of plain Python ints and lists."""
    rows = [list(row) for row in self.vgpr]
    return StateSnapshot(pc=self.pc, scc=self.scc, vcc=self.vcc, exec_mask=self.exec_mask,
                         sgpr=list(self.sgpr), vgpr=rows)
|
|
|
|
class RustEmulator:
|
|
def __init__(self):
|
|
self.lib = ctypes.CDLL(str(REMU_PATH))
|
|
self.lib.wave_create.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32]
|
|
self.lib.wave_create.restype = ctypes.c_void_p
|
|
self.lib.wave_step.argtypes = [ctypes.c_void_p]
|
|
self.lib.wave_step.restype = ctypes.c_int32
|
|
self.lib.wave_get_snapshot.argtypes = [ctypes.c_void_p, ctypes.POINTER(CStateSnapshot)]
|
|
self.lib.wave_set_sgpr.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32]
|
|
self.lib.wave_set_vgpr.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint32, ctypes.c_uint32]
|
|
self.lib.wave_init_lds.argtypes = [ctypes.c_void_p, ctypes.c_uint32]
|
|
self.lib.wave_free.argtypes = [ctypes.c_void_p]
|
|
self.ctx = None
|
|
|
|
def create(self, kernel: bytes, n_lanes: int):
|
|
kernel_buf = (ctypes.c_char * len(kernel)).from_buffer_copy(kernel)
|
|
self.ctx = self.lib.wave_create(ctypes.addressof(kernel_buf), len(kernel), n_lanes)
|
|
self._kernel_buf = kernel_buf
|
|
|
|
def step(self) -> int: return self.lib.wave_step(self.ctx)
|
|
def set_sgpr(self, idx: int, val: int): self.lib.wave_set_sgpr(self.ctx, idx, val)
|
|
def set_vgpr(self, lane: int, idx: int, val: int): self.lib.wave_set_vgpr(self.ctx, lane, idx, val)
|
|
def init_lds(self, size: int): self.lib.wave_init_lds(self.ctx, size)
|
|
|
|
def get_snapshot(self) -> StateSnapshot:
|
|
snap = CStateSnapshot()
|
|
self.lib.wave_get_snapshot(self.ctx, ctypes.byref(snap))
|
|
return snap.to_snapshot()
|
|
|
|
def free(self):
|
|
if self.ctx: self.lib.wave_free(self.ctx); self.ctx = None
|
|
|
|
class PythonEmulator:
  """Adapter that drives the Python emulator through the same methods RustEmulator offers."""

  def __init__(self):
    """Start empty; `create` must be called before stepping or touching registers."""
    self.n_lanes = 0
    self.lds: bytearray | None = None
    self.program: dict | None = None
    self.state: WaveState | None = None

  def create(self, kernel: bytes, n_lanes: int):
    """Decode `kernel` and build a fresh wave state with the low `n_lanes` bits of the exec mask set."""
    self.program = decode_program(kernel)
    wave = WaveState()
    wave.exec_mask = (1 << n_lanes) - 1
    self.state = wave
    self.lds = LDSMem(bytearray(65536))
    self.n_lanes = n_lanes

  def step(self) -> int:
    """Run one `step_wave` call and return its result code."""
    assert self.program is not None
    assert self.state is not None
    assert self.lds is not None
    return step_wave(self.program, self.state, self.lds, self.n_lanes)

  def set_sgpr(self, idx: int, val: int):
    """Write `val`, truncated to 32 bits, to scalar register `idx`."""
    wave = self.state
    assert wave is not None
    wave.sgpr[idx] = val & 0xffffffff

  def set_vgpr(self, lane: int, idx: int, val: int):
    """Write `val`, truncated to 32 bits, to vector register `idx` of `lane`."""
    wave = self.state
    assert wave is not None
    wave.vgpr[lane][idx] = val & 0xffffffff

  def get_snapshot(self) -> 'StateSnapshot':
    """Copy the current wave state into a StateSnapshot (vcc and exec_mask masked to 32 bits)."""
    wave = self.state
    assert wave is not None
    lanes = [list(wave.vgpr[lane]) for lane in range(WAVE_SIZE)]
    return StateSnapshot(pc=wave.pc, scc=wave.scc, vcc=wave.vcc & 0xffffffff,
                         exec_mask=wave.exec_mask & 0xffffffff, sgpr=list(wave.sgpr), vgpr=lanes)
|
|
|
|
def run_single_kernel(kernel: bytes, n_lanes: int, args_ptr: int, global_size: tuple[int, int, int],
                      program, max_steps: int, debug: bool, trace_len: int, kernel_idx: int = 0,
                      max_workgroups: int = 8) -> tuple[bool, str, int]:
  """Run a single kernel through both emulators. Returns (success, message, total_steps).

  For every workgroup id in `global_size` (z outermost, x innermost) a fresh RustEmulator/PythonEmulator pair is
  created with `n_lanes` lanes and stepped in lock-step. Before each instruction the two snapshots are compared;
  the first mismatch returns `(False, report, total_steps)` with the last `trace_len` instructions in the report.

  Args:
    kernel: machine code handed to both emulators.
    n_lanes: number of lanes to create and to compare.
    args_ptr: 64-bit address written to s0 (low half) and s1 (high half).
    global_size: (x, y, z) workgroup counts.
    program: mapping from pc to a decoded instruction exposing `disasm()`; used only for reporting.
    max_steps: per-workgroup instruction limit; reaching it is reported as a failure.
    debug: print every instruction before it is executed.
    trace_len: number of recent instructions kept for the failure report.
    kernel_idx: index used as the `K<idx>` prefix in messages.
    max_workgroups: after this many workgroups the function returns success without running the rest.

  Raises:
    unittest.SkipTest: when the Rust step returns 1 while the Python step returns 0.
  """
  gx, gy, gz = global_size
  total_steps = 0
  wg_count = 0

  for gidz in range(gz):
    for gidy in range(gy):
      for gidx in range(gx):
        # Workgroup cap: report success as soon as max_workgroups workgroups have been compared
        if wg_count >= max_workgroups: return True, f"Completed {wg_count} workgroups (limit reached)", total_steps
        wg_count += 1
        rust = RustEmulator()
        python = PythonEmulator()
        rust.create(kernel, n_lanes)
        python.create(kernel, n_lanes)

        # Initialize LDS (64KB, standard size for AMD GPUs)
        rust.init_lds(65536)

        # Seed both emulators identically: s0/s1 = args pointer (low/high 32 bits), s13/s14/s15 = workgroup id x/y/z
        for emu in (rust, python):
          emu.set_sgpr(0, args_ptr & 0xffffffff)
          emu.set_sgpr(1, (args_ptr >> 32) & 0xffffffff)
          emu.set_sgpr(13, gidx)
          emu.set_sgpr(14, gidy)
          emu.set_sgpr(15, gidz)

        step = 0
        # Sliding window of (step, pc, disassembly, rust state before, python state before) for failure reports
        trace: list[tuple[int, int, str, StateSnapshot, StateSnapshot]] = []
        try:
          while step < max_steps:
            rust_before = rust.get_snapshot()
            python_before = python.get_snapshot()

            inst = program.get(python_before.pc)
            inst_str = inst.disasm() if inst else f"unknown at PC={python_before.pc}"
            trace.append((step, python_before.pc, inst_str, rust_before, python_before))
            # Keep only the most recent trace_len entries
            if len(trace) > trace_len: trace.pop(0)

            if debug: print(f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: PC={python_before.pc}, inst={inst_str}")

            # Instructions with known Rust emulator bugs - sync Python to Rust after execution
            # v_div_scale/v_div_fixup: Rust has different VCC handling
            # v_cvt_f16_f32: Rust clears high 16 bits, but hardware (and Python) preserves them
            sync_after = any(x in inst_str for x in ('v_div_scale_f32', 'v_div_scale_f64', 'v_div_fixup_f32', 'v_div_fixup_f64',
                                                     'v_cvt_f16_f32'))
            # States are compared BEFORE executing, so a mismatch was caused by an earlier instruction
            diffs = rust_before.diff(python_before, n_lanes)
            if diffs:
              # Build a per-instruction history showing what each emulator changed (at most 5 changes per line)
              trace_lines = []
              for idx, (s, pc, d, rb, pb) in enumerate(trace):
                trace_lines.append(f" step {s}: PC={pc:3d} {d}")
                if idx < len(trace) - 1:
                  # Effect of this instruction = state before it vs. state before the next traced instruction
                  next_rb, next_pb = trace[idx + 1][3:5]
                  rust_diffs = rb.diff(next_rb, n_lanes, "->")
                  python_diffs = pb.diff(next_pb, n_lanes, "->")
                  if rust_diffs: trace_lines.append(f" rust: {', '.join(rust_diffs[:5])}")
                  if python_diffs: trace_lines.append(f" python: {', '.join(python_diffs[:5])}")
                  elif rust_diffs: trace_lines.append(f" python: (no changes)")
                else:
                  # Last traced instruction - compare with current state
                  rust_diffs = rb.diff(rust_before, n_lanes, "->")
                  python_diffs = pb.diff(python_before, n_lanes, "->")
                  if rust_diffs: trace_lines.append(f" rust: {', '.join(rust_diffs[:5])}")
                  if python_diffs: trace_lines.append(f" python: {', '.join(python_diffs[:5])}")
                  elif rust_diffs: trace_lines.append(f" python: (no changes)")
              trace_str = "\n".join(trace_lines)
              return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step} before inst '{inst_str}': states differ (rust vs python):\n " + "\n ".join(diffs[:10]) + f"\n Recent instructions:\n{trace_str}", total_steps

            rust_result = rust.step()
            python_result = python.step()

            if rust_result != python_result:
              # Rust returns 1 for unsupported instructions - skip test
              if rust_result == 1 and python_result == 0:
                raise unittest.SkipTest(f"Rust emulator doesn't support instruction: {inst_str}")
              trace_str = "\n".join(f" step {s}: PC={pc:3d} {d}" for s, pc, d, _, _ in trace)
              return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: different return codes: rust={rust_result}, python={python_result}, inst={inst_str}\n Recent instructions:\n{trace_str}", total_steps

            # Sync Python state to Rust after instructions with known Rust emulator differences
            if sync_after:
              rust_after = rust.get_snapshot()
              for i in range(128): python.set_sgpr(i, rust_after.sgpr[i])
              for lane in range(n_lanes):
                for i in range(256): python.set_vgpr(lane, i, rust_after.vgpr[lane][i])
              assert python.state is not None
              python.state.pc, python.state.scc, python.state.vcc, python.state.exec_mask = rust_after.pc, rust_after.scc, rust_after.vcc, rust_after.exec_mask

            # Result codes -1 and 1 both end this workgroup normally (presumably -1 = program finished - TODO confirm)
            if rust_result == -1:
              total_steps += step + 1
              break
            if rust_result == 1:
              total_steps += step + 1
              break
            # Any other negative code except -2 is treated as an emulator error
            if rust_result < 0 and rust_result != -2:
              return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Step {step}: error code {rust_result}", total_steps

            step += 1
          else:
            # while/else: reached only when the loop ran out of steps without hitting a break
            return False, f"K{kernel_idx} WG({gidx},{gidy},{gidz}) Max steps ({max_steps}) reached", total_steps
        finally:
          # Release the Rust-side wave on every exit path (return, break, SkipTest, other exceptions)
          rust.free()

  return True, f"Completed {gx*gy*gz} workgroups", total_steps
|
|
|
|
def compare_emulators_multi_kernel(kernels: list[KernelInfo], buf_pool: dict[int, int], max_steps: int = 1000,
|
|
debug: bool = False, trace_len: int = 10, buf_data: dict[int, bytes] | None = None) -> tuple[bool, str]:
|
|
"""Run all kernels through both emulators with shared buffer pool."""
|
|
if buf_data is None: buf_data = {}
|
|
|
|
# Allocate shared buffer pool with padding for over-reads (GPU loads up to 16 bytes at once)
|
|
buf_id_to_ptr: dict[int, int] = {}
|
|
buffers = []
|
|
for buf_id, size in buf_pool.items():
|
|
padded_size = ((size + 15) // 16) * 16 + 16 # round up to 16 bytes + extra padding
|
|
# Initialize with data from COPY if available
|
|
init_data = buf_data.get(buf_id, b'\x00' * padded_size)
|
|
init_list = list(init_data) + [0] * (padded_size - len(init_data))
|
|
buf = (ctypes.c_uint8 * padded_size)(*init_list[:padded_size])
|
|
buffers.append((buf, padded_size))
|
|
buf_id_to_ptr[buf_id] = ctypes.addressof(buf)
|
|
|
|
# Set up valid memory ranges
|
|
ranges = {(ctypes.addressof(b), size) for b, size in buffers}
|
|
|
|
total_steps = 0
|
|
for ki, kernel in enumerate(kernels):
|
|
# Create args array for this kernel's buffers
|
|
args = (ctypes.c_uint64 * len(kernel.buf_idxs))(*[buf_id_to_ptr[bid] for bid in kernel.buf_idxs])
|
|
args_ptr = ctypes.addressof(args)
|
|
|
|
# Update valid ranges to include this args array
|
|
kernel_ranges = ranges | {(args_ptr, ctypes.sizeof(args))}
|
|
set_valid_mem_ranges(kernel_ranges)
|
|
|
|
program = decode_program(kernel.code)
|
|
n_lanes = kernel.local_size[0] * kernel.local_size[1] * kernel.local_size[2]
|
|
|
|
ok, msg, steps = run_single_kernel(
|
|
kernel.code, min(n_lanes, 32), args_ptr, kernel.global_size,
|
|
program, max_steps, debug, trace_len, ki
|
|
)
|
|
total_steps += steps
|
|
if not ok:
|
|
return False, msg
|
|
|
|
return True, f"Completed {len(kernels)} kernels, {total_steps} total steps"
|
|
|
|
def compare_emulators_with_memory(kernel: bytes, n_lanes: int, buf_sizes: list, max_steps: int = 1000, debug: bool = False,
                                  global_size: tuple[int, int, int] = (1, 1, 1), trace_len: int = 10) -> tuple[bool, str]:
  """Legacy wrapper: run one kernel on both emulators with freshly allocated, zero-filled buffers.

  One buffer of each size in `buf_sizes` is allocated and their addresses are passed as the kernel arguments.
  All workgroups in `global_size` are executed through run_single_kernel (subject to its workgroup limit).
  Returns `(success, message)`.
  """
  # ctypes arrays created without initialisers start out zeroed
  buffers = [(ctypes.c_uint8 * size)() for size in buf_sizes]

  # Kernel arguments: an array of 64-bit pointers, one per buffer
  arg_ptrs = [ctypes.addressof(buf) for buf in buffers]
  args = (ctypes.c_uint64 * len(arg_ptrs))(*arg_ptrs)
  args_ptr = ctypes.addressof(args)

  # Valid memory for the Python emulator: every buffer plus the argument array itself
  ranges = {(ptr, len(buf)) for ptr, buf in zip(arg_ptrs, buffers)}
  ranges.add((args_ptr, ctypes.sizeof(args)))
  set_valid_mem_ranges(ranges)

  decoded = decode_program(kernel)
  ok, msg, _ = run_single_kernel(kernel, n_lanes, args_ptr, global_size, decoded, max_steps, debug, trace_len)
  return ok, msg
|
|
|
|
def get_kernels_from_tinygrad(op_fn) -> tuple[list[KernelInfo], dict[int, int], dict[int, bytes]]:
  """Schedule the tinygrad computation built by `op_fn` and collect every compiled kernel.

  `op_fn` is called with the `Tensor` class and must return an object with a `schedule()` method.

  Returns `(kernels, buf_pool, buf_data)`:
    kernels: one KernelInfo per `.text` section of each SINK item that has a compiled binary.
    buf_pool: id(buffer) -> nbytes for every buffer seen.
    buf_data: id(destination buffer) -> source bytes, for COPY items whose source exposes `base._buf`.

  Raises RuntimeError when no kernel was found.
  """
  from tinygrad import Tensor
  from tinygrad.runtime.support.elf import elf_loader

  kernels = []
  buf_pool: dict[int, int] = {}
  buf_data: dict[int, bytes] = {}

  result = op_fn(Tensor)
  schedule = result.schedule()
  for item in schedule:
    lowered = item.lower()
    op_name = item.ast.op.name

    if op_name == 'COPY':
      # A COPY only matters here as the source of initial contents for its destination buffer
      if len(lowered.bufs) < 2: continue
      dst_buf, src_buf = lowered.bufs[0], lowered.bufs[1]
      dst_id = id(dst_buf)
      if dst_id not in buf_pool: buf_pool[dst_id] = dst_buf.nbytes
      # Only sources exposing `base._buf` (numpy/CPU data) can be read back as bytes
      src_base = getattr(src_buf, 'base', None)
      if src_base is not None and hasattr(src_base, '_buf'):
        buf_data[dst_id] = bytes(src_base._buf)
      continue

    if op_name != 'SINK': continue
    if not (lowered.prg and lowered.prg.p.lib): continue

    _, sections, _ = elf_loader(bytes(lowered.prg.p.lib))
    for sec in sections:
      if sec.name != '.text': continue
      buf_idxs = [id(b) for b in lowered.bufs]
      buf_sizes = [b.nbytes for b in lowered.bufs]
      # Register buffers not seen before; sizes already recorded are left untouched
      for buf_id, nbytes in zip(buf_idxs, buf_sizes): buf_pool.setdefault(buf_id, nbytes)
      kernels.append(KernelInfo(code=bytes(sec.content),
                                global_size=tuple(lowered.prg.p.global_size),
                                local_size=tuple(lowered.prg.p.local_size),
                                buf_idxs=buf_idxs,
                                buf_sizes=buf_sizes))

  if not kernels: raise RuntimeError("No kernel found")
  return kernels, buf_pool, buf_data
|
|
|
|
def get_kernel_from_tinygrad(op_fn) -> tuple[bytes, tuple[int, int, int], tuple[int, int, int], list]:
  """Legacy wrapper: compile `op_fn` and return only the last kernel as `(code, global_size, local_size, buf_sizes)`."""
  all_kernels, _pool, _data = get_kernels_from_tinygrad(op_fn)
  last = all_kernels[-1]
  return last.code, last.global_size, last.local_size, last.buf_sizes
|
|
|
|
class TestTinygradKernels(unittest.TestCase):
  """Compare emulators on real tinygrad-compiled kernels."""

  def _test_kernel(self, op_fn, max_steps=10000):
    # Compile op_fn into kernels, replay them on both emulators and fail with the mismatch report if they diverge
    kernels, buf_pool, buf_data = get_kernels_from_tinygrad(op_fn)
    passed, report = compare_emulators_multi_kernel(kernels, buf_pool, max_steps=max_steps, buf_data=buf_data)
    self.assertTrue(passed, report)

  # Basic ops - consolidated tests covering key instruction patterns
  def test_unary_ops(self):
    self._test_kernel(lambda T: T([-1.0, 0.0, 1.0, 2.0]).relu().exp().log().sqrt().reciprocal())

  def test_binary_ops(self):
    self._test_kernel(lambda T: (T([1.0, 2.0]) + T([3.0, 4.0])) * T([0.5, 0.5]) - T([1.0, 1.0]))

  def test_trig(self):
    self._test_kernel(lambda T: T([0.1, 1.0, 3.14, -1.0]*8).sin() + T([0.1, 1.0, 3.14, -1.0]*8).cos())

  def test_compare(self):
    self._test_kernel(lambda T: (T.empty(64) < T.empty(64)).where(T.empty(64), T.empty(64)))

  def test_bitwise(self):
    self._test_kernel(lambda T: (T([0xF0, 0x0F, 0xFF]*11).int() & T([0x0F, 0x0F, 0x00]*11).int()) | T([1]*33).int())

  def test_int_ops(self):
    self._test_kernel(lambda T: ((T.empty(64).int() + T.empty(64).int()) * T.empty(64).int()).float())

  # Reductions
  def test_reduce(self):
    self._test_kernel(lambda T: T.empty(64).sum() + T.empty(64).max())

  def test_argmax(self):
    self._test_kernel(lambda T: T.empty(64).argmax())

  # Matmul
  def test_gemm(self):
    self._test_kernel(lambda T: T.empty(8, 8) @ T.empty(8, 8), max_steps=100000)

  @unittest.skip("Rust emulator crashes on this kernel (assertion failure in thread.rs)")
  def test_gemm_fp16(self):
    self._test_kernel(lambda T: T.empty(16, 16).half() @ T.empty(16, 16).half(), max_steps=100000)

  # Complex ops
  def test_softmax(self):
    self._test_kernel(lambda T: T.empty(16).softmax())

  def test_layernorm(self):
    self._test_kernel(lambda T: T.empty(8, 8).layernorm())

  # Memory patterns
  def test_memory(self):
    self._test_kernel(lambda T: T.empty(4, 4).permute(1, 0).contiguous() + T.empty(4, 1).expand(4, 4))

  # Cast ops
  def test_cast(self):
    self._test_kernel(lambda T: T.empty(32).half().float() + T.empty(32).int().float())

  # Pooling - regression for VCC wave32 mode
  def test_pool2d(self):
    self._test_kernel(lambda T: T.empty(1, 1, 8, 8).avg_pool2d(kernel_size=(4,4)) + T.empty(1, 1, 8, 8).max_pool2d(kernel_size=(4,4)))

  # Convolution
  def test_conv2d(self):
    self._test_kernel(lambda T: T.empty(1, 2, 8, 8).conv2d(T.empty(2, 2, 3, 3)), max_steps=50000)

  # Regression tests
  def test_topk(self):
    self._test_kernel(lambda T: T.empty(64).topk(3)[0])

  def test_interpolate(self):
    self._test_kernel(lambda T: T.empty(1,2,16,16).relu().cast('uint8').interpolate((8,8), mode="linear"))

  def test_index_int64(self):
    from tinygrad import dtypes
    self._test_kernel(lambda T: T.empty(4, 4)[T.arange(4).cast(dtypes.int64), :])

  def test_gelu(self):
    self._test_kernel(lambda T: T.empty(32, 32).gelu())

  def test_cross_entropy(self):
    import numpy as np
    # Fixed seed so the generated labels and logits are the same on every run
    np.random.seed(0)
    classes = np.random.randint(0, 10, (16,), dtype=np.int32).tolist()
    x_np = np.random.randn(16, 10).astype(np.float32)
    self._test_kernel(lambda T: (T(x_np.tolist()).reshape(16,10) + 0).cross_entropy((T(classes).int().reshape(16) + 0)))

  def test_isinf(self):
    self._test_kernel(lambda T: T([float('-inf'), 0., float('inf'), 1.1]*8).isinf())
|
|
|
|
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
|