assembly/amd: fix all ops tests (#13910)

* assembly/amd: fix all ops tests

* test_ops with smaller sizes

* ds store/load 2addr
George Hotz
2025-12-30 18:01:34 -05:00
committed by GitHub
parent dc27eb48ac
commit 0221b96761
4 changed files with 279 additions and 33 deletions

View File

@@ -21,22 +21,10 @@ tinygrad's dtype tests should pass with and without LLVM. they run in about 12 s
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/test_dtype_alu.py test/test_dtype.py`
`PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/test_dtype_alu.py test/test_dtype.py`
-The ops tests also mostly pass, but they are very slow, so you should run them one at a time.
+The ops tests also pass, but they are very slow, so you should run them one at a time.
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=0 pytest -n=12 test/test_ops.py`
-with failure
-`FAILED test/test_ops.py::TestOps::test_avg_pool3d - Exception: forward pass failed shape (1, 1, 3, 3, 3):`
`SKIP_SLOW_TEST=1 PYTHONPATH="." AMD=1 PYTHON_REMU=1 MOCKGPU=1 AMD_LLVM=1 pytest -n=12 test/test_ops.py`
-with failures
-```
-FAILED test/test_ops.py::TestOps::test_avg_pool3d - Exception: forward pass failed shape (1, 1, 3, 3, 3):
-FAILED test/test_ops.py::TestOps::test_negative_padding_conv2d - Exception: backward pass tensor 0 failed shape (1, 1, 10, 10):
-FAILED test/test_ops.py::TestOps::test_unfold - Exception: backward pass tensor 0 failed shape (8,):
-FAILED test/test_ops.py::TestOps::test_avg_pool2d_ceil_mode - Exception: backward pass tensor 0 failed shape (1, 1, 6, 6):
-```
-TODO: make all ops tests pass and add local regression tests.
+When something is caught by the main tinygrad tests, a local regression test should be added to `extra/assembly/amd/test`. While working with tinygrad, you can dump the assembly with `DEBUG=7`. These tests all pass on real hardware, so if a test fails with `AMD=1 PYTHON_REMU=1 MOCKGPU=1`, it is likely because an instruction is emulated incorrectly. You can drop `MOCKGPU=1` to run on real hardware; if it works there, the bug is in the emulator.

View File

@@ -82,6 +82,9 @@ FLAT_D16_LOAD = _mem_ops([GLOBALOp, FLATOp], _D16_LOAD_MAP)
FLAT_D16_STORE = _mem_ops([GLOBALOp, FLATOp], _D16_STORE_MAP)
DS_LOAD = {DSOp.DS_LOAD_B32: (1,4,0), DSOp.DS_LOAD_B64: (2,4,0), DSOp.DS_LOAD_B128: (4,4,0), DSOp.DS_LOAD_U8: (1,1,0), DSOp.DS_LOAD_I8: (1,1,1), DSOp.DS_LOAD_U16: (1,2,0), DSOp.DS_LOAD_I16: (1,2,1)}
DS_STORE = {DSOp.DS_STORE_B32: (1,4), DSOp.DS_STORE_B64: (2,4), DSOp.DS_STORE_B128: (4,4), DSOp.DS_STORE_B8: (1,1), DSOp.DS_STORE_B16: (1,2)}
+# 2ADDR ops: load/store two values using offset0 and offset1
+DS_LOAD_2ADDR = {DSOp.DS_LOAD_2ADDR_B32: 4, DSOp.DS_LOAD_2ADDR_B64: 8}
+DS_STORE_2ADDR = {DSOp.DS_STORE_2ADDR_B32: 4, DSOp.DS_STORE_2ADDR_B64: 8}
SMEM_LOAD = {SMEMOp.S_LOAD_B32: 1, SMEMOp.S_LOAD_B64: 2, SMEMOp.S_LOAD_B128: 4, SMEMOp.S_LOAD_B256: 8, SMEMOp.S_LOAD_B512: 16}
# VOPD op -> VOP3 op mapping (VOPD is dual-issue of VOP1/VOP2 ops, use VOP3 enums for pseudocode lookup)
@@ -227,8 +230,10 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
literal = inst.simm16 if inst_type in (SOPK, SOPP) else st.literal
# Execute compiled function - pass PC in bytes for instructions that need it
+# For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant
pc_bytes = st.pc * 4
-result = fn(s0, s1, 0, d0, st.scc, st.vcc, 0, exec_mask, literal, None, {}, pc=pc_bytes)
+vcc32, exec32 = st.vcc & MASK32, exec_mask & MASK32
+result = fn(s0, s1, 0, d0, st.scc, vcc32, 0, exec32, literal, None, {}, pc=pc_bytes)
# Apply results
if sdst is not None:
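
The 32-bit masking above is the core of the wave32 fix: stale VCC_HI bits must not leak into scalar ops or VCC branches. A minimal sketch of the invariant (values illustrative, MASK32 inlined):
```
# In wave32, only the low 32 bits of VCC and EXEC are architecturally live.
MASK32 = 0xffffffff
vcc = 0x8000_0000_0000_0000  # stale VCC_HI bits set, VCC_LO == 0
assert vcc & MASK32 == 0     # wave32 sees VCC as zero, so S_CBRANCH_VCCNZ
                             # must fall through and S_CBRANCH_VCCZ must branch
```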
@@ -270,13 +275,31 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
return
if inst_type is DS:
-op, addr, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
+op, addr0, vdst = inst.op, (V[inst.addr] + inst.offset0) & 0xffff, inst.vdst
if op in DS_LOAD:
cnt, sz, sign = DS_LOAD[op]
-for i in range(cnt): val = int.from_bytes(lds[addr+i*sz:addr+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
+for i in range(cnt): val = int.from_bytes(lds[addr0+i*sz:addr0+i*sz+sz], 'little'); V[vdst + i] = _sext(val, sz * 8) & MASK32 if sign else val
elif op in DS_STORE:
cnt, sz = DS_STORE[op]
-for i in range(cnt): lds[addr+i*sz:addr+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
+for i in range(cnt): lds[addr0+i*sz:addr0+i*sz+sz] = (V[inst.data0 + i] & ((1 << (sz * 8)) - 1)).to_bytes(sz, 'little')
+elif op in DS_LOAD_2ADDR:
+# Load two values from addr+offset0*sz and addr+offset1*sz into vdst (B32: 1 dword each, B64: 2 dwords each)
+# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
+sz = DS_LOAD_2ADDR[op]
+addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
+addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
+cnt = sz // 4 # 1 for B32, 2 for B64
+for i in range(cnt): V[vdst + i] = int.from_bytes(lds[addr0+i*4:addr0+i*4+4], 'little')
+for i in range(cnt): V[vdst + cnt + i] = int.from_bytes(lds[addr1+i*4:addr1+i*4+4], 'little')
+elif op in DS_STORE_2ADDR:
+# Store two values from data0 and data1 to addr+offset0*sz and addr+offset1*sz
+# Note: offsets are scaled by data size (4 for B32, 8 for B64) per AMD ISA
+sz = DS_STORE_2ADDR[op]
+addr0 = (V[inst.addr] + inst.offset0 * sz) & 0xffff
+addr1 = (V[inst.addr] + inst.offset1 * sz) & 0xffff
+cnt = sz // 4
+for i in range(cnt): lds[addr0+i*4:addr0+i*4+4] = (V[inst.data0 + i] & MASK32).to_bytes(4, 'little')
+for i in range(cnt): lds[addr1+i*4:addr1+i*4+4] = (V[inst.data1 + i] & MASK32).to_bytes(4, 'little')
else: raise NotImplementedError(f"DS op {op}")
return
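
The subtle part of the 2ADDR forms is that offset0/offset1 are element indices, not byte offsets. A standalone sketch of the address math (function name illustrative; the 0xffff mask mirrors the 64 KiB LDS wrap above):
```
# DS 2ADDR addressing: offsets scale by element size (4 bytes for B32, 8 for B64).
def ds_2addr_addrs(base: int, offset0: int, offset1: int, sz: int) -> tuple[int, int]:
    return (base + offset0 * sz) & 0xffff, (base + offset1 * sz) & 0xffff

assert ds_2addr_addrs(0, 0, 1, 8) == (0, 8)   # B64: adjacent elements don't overlap
assert ds_2addr_addrs(0, 2, 5, 4) == (8, 20)  # B32: offsets 2 and 5 -> bytes 8 and 20
```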
@@ -387,7 +410,9 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
is_shift_64 = op in (VOP3Op.V_LSHLREV_B64, VOP3Op.V_LSHRREV_B64, VOP3Op.V_ASHRREV_I64)
# 16-bit source ops: use precomputed sets instead of string checks
# Note: must check op_cls to avoid cross-enum value collisions
-is_16bit_src = op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS
+# VOP3-encoded VOPC 16-bit ops also use opsel (not VGPR bit 7 like non-VOP3 VOPC)
+is_16bit_src = (op_cls is VOP3Op and op in _VOP3_16BIT_OPS and op not in _CVT_32_64_SRC_OPS) or \
+               (inst_type is VOP3 and op_cls is VOPCOp and op in _VOPC_16BIT_OPS)
# VOP2 16-bit ops use f16 inline constants for src0 (vsrc1 is always a VGPR, no inline constants)
is_vop2_16bit = op_cls is VOP2Op and op in _VOP2_16BIT_OPS
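
For the VOP3-encoded VOPC case noted above, opsel bit i selects which 16-bit half of source i feeds the compare. A minimal model (helper name hypothetical), mirroring the regression tests added below:
```
# An opsel bit of 0 selects the low 16-bit half of its source; 1 selects the high half.
def half(src: int, opsel_bit: int) -> int:
    return (src >> 16) & 0xffff if opsel_bit else src & 0xffff

assert half(0x12340005, 0) == half(0x56780005, 0)  # opsel=0: lo vs lo -> 5 == 5
assert half(0x12340005, 1) != half(0x56780005, 1)  # opsel=3: hi vs hi -> 0x1234 != 0x5678
```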

View File

@@ -3850,3 +3850,236 @@ class Test64BitLiterals(unittest.TestCase):
lit_bytes = code[-4:]
lit_val = int.from_bytes(lit_bytes, 'little')
self.assertEqual(lit_val, large_val, f"Encoded literal should be {large_val:#x}, got {lit_val:#x}")
class TestWave32VCCBranch(unittest.TestCase):
"""Regression tests for wave32 VCC branch behavior.
In wave32 mode, S_CBRANCH_VCCNZ/VCCZ should only check VCC_LO (lower 32 bits),
ignoring VCC_HI. Bug: emulator was checking full 64-bit VCC, causing incorrect
branches when VCC_LO=0 but VCC_HI!=0."""
def test_cbranch_vccnz_ignores_vcc_hi(self):
"""S_CBRANCH_VCCNZ should NOT branch when VCC_LO=0, even if VCC_HI!=0.
This is the fix for test_avg_pool3d failure where the emulator incorrectly
branched due to stale VCC_HI bits."""
instructions = [
# Set VCC_HI to non-zero (simulating stale bits from previous ops)
s_mov_b32(s[SrcEnum.VCC_HI - 128], 0x80000000), # VCC_HI = 0x80000000
# Set VCC_LO to zero (the condition we're testing)
s_mov_b32(s[SrcEnum.VCC_LO - 128], 0), # VCC_LO = 0
# Now S_CBRANCH_VCCNZ should NOT branch since VCC_LO is 0
# If it doesn't branch, we'll set v0 = 1; if it branches, v0 stays 0
v_mov_b32_e32(v[0], 0),
s_cbranch_vccnz(2), # Skip next instruction if VCC != 0
v_mov_b32_e32(v[0], 1), # This should execute
s_nop(0), # Jump target
]
st = run_program(instructions, n_lanes=1)
# v0 should be 1 because VCC_LO=0 means no branch
self.assertEqual(st.vgpr[0][0], 1, "Should NOT branch when VCC_LO=0 (VCC_HI ignored in wave32)")
def test_cbranch_vccz_ignores_vcc_hi(self):
"""S_CBRANCH_VCCZ should branch when VCC_LO=0, regardless of VCC_HI."""
instructions = [
# Set VCC_HI to non-zero (simulating stale bits)
s_mov_b32(s[SrcEnum.VCC_HI - 128], 0x80000000), # VCC_HI = 0x80000000
# Set VCC_LO to zero
s_mov_b32(s[SrcEnum.VCC_LO - 128], 0), # VCC_LO = 0
# S_CBRANCH_VCCZ should branch since VCC_LO is 0
v_mov_b32_e32(v[0], 0),
s_cbranch_vccz(2), # Skip next instruction if VCC == 0
v_mov_b32_e32(v[0], 1), # This should NOT execute
s_nop(0), # Jump target
]
st = run_program(instructions, n_lanes=1)
# v0 should be 0 because VCC_LO=0 means branch is taken
self.assertEqual(st.vgpr[0][0], 0, "Should branch when VCC_LO=0 (VCC_HI ignored in wave32)")
def test_cbranch_vccnz_branches_on_vcc_lo(self):
"""S_CBRANCH_VCCNZ should branch when VCC_LO!=0."""
instructions = [
# Set VCC_LO to non-zero
s_mov_b32(s[SrcEnum.VCC_LO - 128], 1), # VCC_LO = 1
s_mov_b32(s[SrcEnum.VCC_HI - 128], 0), # VCC_HI = 0
v_mov_b32_e32(v[0], 0),
s_cbranch_vccnz(2), # Skip next instruction if VCC != 0
v_mov_b32_e32(v[0], 1), # This should NOT execute
s_nop(0), # Jump target
]
st = run_program(instructions, n_lanes=1)
# v0 should be 0 because VCC_LO=1 means branch is taken
self.assertEqual(st.vgpr[0][0], 0, "Should branch when VCC_LO!=0")
class TestVOP3VOPC16Bit(unittest.TestCase):
"""Regression tests for VOP3-encoded VOPC 16-bit comparison instructions.
When VOPC comparisons are encoded in VOP3 format, they use opsel bits to select
which 16-bit half of each source to compare.
Bug: Emulator was ignoring opsel and using VGPR bit 7 encoding instead."""
def test_cmp_eq_u16_opsel_lo_lo(self):
"""V_CMP_EQ_U16 VOP3 with opsel=0 compares lo halves."""
# v0 = 0x12340005 (lo=5, hi=0x1234)
# v1 = 0x56780005 (lo=5, hi=0x5678)
# opsel=0: compare lo halves -> 5 == 5 -> true
instructions = [
s_mov_b32(s[2], 0x12340005),
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x56780005),
v_mov_b32_e32(v[1], s[2]),
VOP3(VOP3Op.V_CMP_EQ_U16, vdst=v[0], src0=v[0], src1=v[1], opsel=0), # dst=s0
]
st = run_program(instructions, n_lanes=1)
# s0 should have bit 0 set (comparison true for lane 0)
self.assertEqual(st.sgpr[0] & 1, 1, "lo==lo should be true: 5==5")
def test_cmp_eq_u16_opsel_hi_hi(self):
"""V_CMP_EQ_U16 VOP3 with opsel=3 compares hi halves."""
# v0 = 0x12340005 (lo=5, hi=0x1234)
# v1 = 0x56780005 (lo=5, hi=0x5678)
# opsel=3 (bits 0 and 1 set): compare hi halves -> 0x1234 != 0x5678 -> false
instructions = [
s_mov_b32(s[2], 0x12340005),
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x56780005),
v_mov_b32_e32(v[1], s[2]),
VOP3(VOP3Op.V_CMP_EQ_U16, vdst=v[0], src0=v[0], src1=v[1], opsel=3), # dst=s0, hi vs hi
]
st = run_program(instructions, n_lanes=1)
# s0 should have bit 0 clear (comparison false for lane 0)
self.assertEqual(st.sgpr[0] & 1, 0, "hi==hi should be false: 0x1234!=0x5678")
def test_cmp_eq_u16_opsel_hi_hi_equal(self):
"""V_CMP_EQ_U16 VOP3 with opsel=3 compares hi halves (equal case)."""
# v0 = 0x12340005 (lo=5, hi=0x1234)
# v1 = 0x12340009 (lo=9, hi=0x1234)
# opsel=3: compare hi halves -> 0x1234 == 0x1234 -> true
instructions = [
s_mov_b32(s[2], 0x12340005),
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x12340009),
v_mov_b32_e32(v[1], s[2]),
VOP3(VOP3Op.V_CMP_EQ_U16, vdst=v[0], src0=v[0], src1=v[1], opsel=3), # dst=s0, hi vs hi
]
st = run_program(instructions, n_lanes=1)
# s0 should have bit 0 set (comparison true for lane 0)
self.assertEqual(st.sgpr[0] & 1, 1, "hi==hi should be true: 0x1234==0x1234")
def test_cmp_gt_u16_opsel_hi(self):
"""V_CMP_GT_U16 VOP3 with opsel=3 compares hi halves."""
# v0 = 0x99990005 (lo=5, hi=0x9999)
# v1 = 0x12340005 (lo=5, hi=0x1234)
# opsel=3: compare hi halves -> 0x9999 > 0x1234 -> true
instructions = [
s_mov_b32(s[2], 0x99990005),
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x12340005),
v_mov_b32_e32(v[1], s[2]),
VOP3(VOP3Op.V_CMP_GT_U16, vdst=v[0], src0=v[0], src1=v[1], opsel=3), # dst=s0, hi vs hi
]
st = run_program(instructions, n_lanes=1)
# s0 should have bit 0 set (comparison true for lane 0)
self.assertEqual(st.sgpr[0] & 1, 1, "hi>hi should be true: 0x9999>0x1234")
class TestDS2Addr(unittest.TestCase):
"""Regression tests for DS_LOAD_2ADDR and DS_STORE_2ADDR instructions.
These ops use offset scaling: offset * sizeof(data) for address calculation.
Bug: Emulator was using offset*4 for both B32 and B64, but B64 needs offset*8."""
def test_ds_store_load_2addr_b32(self):
"""DS_STORE_2ADDR_B32 and DS_LOAD_2ADDR_B32 with offset scaling by 4."""
# Store 0x12345678 at offset0=0 (*4=0) and 0xDEADBEEF at offset1=1 (*4=4)
# Then load them back
instructions = [
v_mov_b32_e32(v[10], 0), # addr base = 0
s_mov_b32(s[2], 0x12345678),
v_mov_b32_e32(v[0], s[2]), # data0
s_mov_b32(s[2], 0xDEADBEEF),
v_mov_b32_e32(v[1], s[2]), # data1
DS(DSOp.DS_STORE_2ADDR_B32, addr=v[10], data0=v[0], data1=v[1], vdst=v[0], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_B32, addr=v[10], vdst=v[2], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0x12345678, "v2 should have value from offset 0")
self.assertEqual(st.vgpr[0][3], 0xDEADBEEF, "v3 should have value from offset 4")
def test_ds_store_load_2addr_b32_nonzero_offsets(self):
"""DS_STORE_2ADDR_B32 with non-zero offsets (offset*4 scaling)."""
# Store at offset0=2 (*4=8) and offset1=5 (*4=20)
instructions = [
v_mov_b32_e32(v[10], 0), # addr base = 0
s_mov_b32(s[2], 0x11111111),
v_mov_b32_e32(v[0], s[2]),
s_mov_b32(s[2], 0x22222222),
v_mov_b32_e32(v[1], s[2]),
DS(DSOp.DS_STORE_2ADDR_B32, addr=v[10], data0=v[0], data1=v[1], vdst=v[0], offset0=2, offset1=5),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_B32, addr=v[10], vdst=v[2], offset0=2, offset1=5),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
self.assertEqual(st.vgpr[0][2], 0x11111111, "v2 should have value from offset 8 (2*4)")
self.assertEqual(st.vgpr[0][3], 0x22222222, "v3 should have value from offset 20 (5*4)")
def test_ds_store_load_2addr_b64(self):
"""DS_STORE_2ADDR_B64 and DS_LOAD_2ADDR_B64 with offset scaling by 8."""
# For B64: each value is 8 bytes (2 dwords), offsets scaled by 8
# Store 64-bit value at offset0=0 (*8=0) and another at offset1=1 (*8=8)
instructions = [
v_mov_b32_e32(v[10], 0), # addr base = 0
# First 64-bit value: 0x123456789ABCDEF0
s_mov_b32(s[2], 0x9ABCDEF0),
v_mov_b32_e32(v[0], s[2]), # low dword
s_mov_b32(s[2], 0x12345678),
v_mov_b32_e32(v[1], s[2]), # high dword
# Second 64-bit value: 0xDEADBEEFCAFEBABE
s_mov_b32(s[2], 0xCAFEBABE),
v_mov_b32_e32(v[2], s[2]), # low dword
s_mov_b32(s[2], 0xDEADBEEF),
v_mov_b32_e32(v[3], s[2]), # high dword
DS(DSOp.DS_STORE_2ADDR_B64, addr=v[10], data0=v[0], data1=v[2], vdst=v[0], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
DS(DSOp.DS_LOAD_2ADDR_B64, addr=v[10], vdst=v[4], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
# v4,v5 = first 64-bit value from offset 0
self.assertEqual(st.vgpr[0][4], 0x9ABCDEF0, "v4 should have low dword of first value")
self.assertEqual(st.vgpr[0][5], 0x12345678, "v5 should have high dword of first value")
# v6,v7 = second 64-bit value from offset 8 (1*8)
self.assertEqual(st.vgpr[0][6], 0xCAFEBABE, "v6 should have low dword of second value")
self.assertEqual(st.vgpr[0][7], 0xDEADBEEF, "v7 should have high dword of second value")
def test_ds_2addr_b64_no_overlap(self):
"""DS_LOAD_2ADDR_B64 with adjacent offsets should not overlap.
Regression test: offset1=1 should access bytes 8-15, not overlap with offset0=0 (bytes 0-7)."""
instructions = [
v_mov_b32_e32(v[10], 0),
# Store 4 distinct dwords at addresses 0,4,8,12 using regular DS_STORE
s_mov_b32(s[2], 0x11111111),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=0),
s_mov_b32(s[2], 0x22222222),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=4),
s_mov_b32(s[2], 0x33333333),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=8),
s_mov_b32(s[2], 0x44444444),
v_mov_b32_e32(v[0], s[2]),
ds_store_b32(addr=v[10], data0=v[0], offset0=12),
s_waitcnt(lgkmcnt=0),
# Load with DS_LOAD_2ADDR_B64: offset0=0 should get 0-7, offset1=1 should get 8-15
DS(DSOp.DS_LOAD_2ADDR_B64, addr=v[10], vdst=v[4], offset0=0, offset1=1),
s_waitcnt(lgkmcnt=0),
]
st = run_program(instructions, n_lanes=1)
# v4,v5 from addr 0-7: 0x11111111, 0x22222222
self.assertEqual(st.vgpr[0][4], 0x11111111, "v4 should be 0x11111111")
self.assertEqual(st.vgpr[0][5], 0x22222222, "v5 should be 0x22222222")
# v6,v7 from addr 8-15: 0x33333333, 0x44444444
self.assertEqual(st.vgpr[0][6], 0x33333333, "v6 should be 0x33333333")
self.assertEqual(st.vgpr[0][7], 0x44444444, "v7 should be 0x44444444")

View File

@@ -848,7 +848,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: x.cos())
helper_test_op([()], lambda x: x.cos())
if not ((getenv("MOCKGPU") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU"):
-helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
+helper_test_op(None, lambda x: x.cos(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend")
@@ -859,8 +859,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: x.tan(), low=-5, high=5)
helper_test_op([()], lambda x: x.tan())
if not ((getenv("MOCKGPU") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU"):
-helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
-helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
+helper_test_op(None, lambda x: x.tan(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
+helper_test_op(None, lambda x: x.tan(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
def test_asin(self):
@@ -1655,7 +1655,7 @@ class TestOps(unittest.TestCase):
def test_broadcast_full(self):
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
(torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
-for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
+for shapes in [((5,3,14,16), (5,1,14,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
with self.subTest(op=torch_op.__name__, shapes=shapes):
if tinygrad_op != Tensor.pow:
helper_test_op(shapes, torch_op, tinygrad_op)
@@ -2078,7 +2078,7 @@ class TestOps(unittest.TestCase):
lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]), grad_rtol=1e-5)
def test_simple_conv2d_m4(self):
-helper_test_op([(1,16,18,18), (16,16,3,3)],
+helper_test_op([(1,16,9,9), (16,16,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w),
lambda x,w: Tensor.conv2d(x,w), atol=1e-05, grad_rtol=1e-5)
@@ -2535,7 +2535,7 @@ class TestOps(unittest.TestCase):
@slow_test
def test_avg_pool2d(self):
-shape = (32,2,111,28)
+shape = (32,2,11,28)
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
@@ -2549,7 +2549,7 @@ class TestOps(unittest.TestCase):
@slow_test
def test_avg_pool2d_padding(self):
-shape = (32,2,111,28)
+shape = (32,2,11,28)
for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
for p in [1, (1,0), (0,1)]:
with self.subTest(kernel_size=ksz, padding=p):
@@ -2557,10 +2557,10 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=p),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=p), rtol=1e-5)
with self.assertRaises(ValueError):
-Tensor.avg_pool2d(Tensor.randn((32,2,111,28)), kernel_size=(2,2), padding=(1,1,1))
+Tensor.avg_pool2d(Tensor.randn((32,2,11,28)), kernel_size=(2,2), padding=(1,1,1))
def test_avg_pool2d_asymmetric_padding(self):
-shape = (32,2,111,28)
+shape = (32,2,11,28)
for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]:
with self.subTest(padding=p):
helper_test_op([shape],
@@ -2571,7 +2571,7 @@ class TestOps(unittest.TestCase):
@slow_test
def test_avg_pool2d_padding_not_counted(self):
-shape = (32,2,111,28)
+shape = (32,2,11,28)
for ksz in [(2,2), (3,3), 2, 3, (3,2)]:
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
@@ -2607,9 +2607,9 @@ class TestOps(unittest.TestCase):
lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True))
def test_global_avg_pool2d(self):
-helper_test_op([(32,2,111,28)],
-lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
-lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
+helper_test_op([(32,2,11,28)],
+lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(11,28)),
+lambda x: Tensor.avg_pool2d(x, kernel_size=(11,28)), rtol=1e-5)
def test_avg_pool3d(self):
# TODO: AMD_LLVM has larger atol
@@ -3142,10 +3142,10 @@ class TestOps(unittest.TestCase):
lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), Tensor(weight), reduction=r))
def test_nll_loss_3d_weight(self):
-target = np.random.randint(0, 10, (32,3,3,3), dtype=np.int32).tolist()
+target = np.random.randint(0, 10, (16,3,3,3), dtype=np.int32).tolist()
weight = np.random.normal(0, 1, (10,)).astype(np.float32).tolist()
for r in ("mean", "sum", "none"):
-helper_test_op([(32,10,3,3,3)],
+helper_test_op([(16,10,3,3,3)],
lambda x: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.tensor(target), torch.tensor(weight), reduction=r),
lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), Tensor(weight), reduction=r))