def test_conv2d_ceildiv_edge_case(self):
  # Exercises the symbolic ceildiv used when conv2d computes its output length.
  # For a bound length of 79 the correct ceiling result is 13; a floor-style simplification yields 12.
  v = Variable('v', 11, 100)
  bound_len = 79
  x_full = Tensor.randn(1, 8, 100)
  weight = Tensor.randn(16, 8, 12)
  conv_kwargs = dict(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))

  # symbolic path: slice by the bound variable, then resolve each output dimension to an int
  sym_out = x_full[:, :, :v.bind(bound_len)].conv2d(**conv_kwargs)
  bindings = {v.expr: bound_len}
  resolved = tuple(sym_infer(dim, bindings) for dim in sym_out.shape)
  self.assertEqual(resolved, (1, 16, 13))

  # concrete path: the same data sliced with a plain int is the reference result
  ref_out = x_full[:, :, :bound_len].conv2d(**conv_kwargs)
  np.testing.assert_allclose(sym_out[:, :, :13].numpy(), ref_out.numpy(), atol=1e-5, rtol=1e-5)
def ceildiv(num, amt):
  """Return `num / amt` rounded up (ceiling division).

  Args:
    num: dividend; a plain number, or a symbolic value exposing a `vmin` lower bound.
    amt: divisor; a plain number, or a symbolic value exposing a `vmin` lower bound.

  Returns:
    The ceiling quotient. A float produced by the generic path is returned as `int`;
    symbolic operands yield whatever their `//` operator returns.
  """
  # Known lower bound of the divisor: an int is its own bound, a symbolic value reports `vmin`.
  # Anything else (e.g. a float) has no usable bound and must take the generic path below
  # instead of failing on a missing `.vmin` attribute.
  amt_min = amt if isinstance(amt, int) else getattr(amt, 'vmin', None)
  # use (num + amt - 1) // amt when num is symbolic and non-negative to avoid C/Python division mismatch
  if hasattr(num, 'vmin') and num.vmin >= 0 and amt_min is not None and amt_min > 0: return (num + amt - 1) // amt
  # generic path: negate, floor-divide, negate; float floor division gives an integral float, so return it as int
  return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret