new ceildiv that fixed symbolic conv (#13944)

* new ceildiv that fixed symbolic conv

* smaller test case
chenyu
2026-01-01 09:02:41 -05:00
committed by GitHub
parent 6a5430ab00
commit 17ef4af72c
3 changed files with 37 additions and 9 deletions

View File

@@ -288,17 +288,22 @@ class TestSymbolicOps(unittest.TestCase):
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
   def test_conv2d_ceildiv_edge_case(self):
-    v = Variable('v', 11, 50_000)
-    val = 39601
-    x = Tensor.randn(1, 22, 50_000)[:, :, :v.bind(val)]
-    weight = Tensor.randn(256, 22, 12)
+    # tests symbolic ceildiv in conv2d output shape calculation
+    # val=79 triggers the edge case where old ceildiv simplifies incorrectly: old gives floor=12, correct ceildiv=13
+    v = Variable('v', 11, 100)
+    val = 79
+    x_full = Tensor.randn(1, 8, 100)
+    weight = Tensor.randn(16, 8, 12)
-    result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
+    # symbolic version
+    result = x_full[:, :, :v.bind(val)].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
     var_val = {v.expr: val}
     shape = tuple(sym_infer(s, var_val) for s in result.shape)
-    with self.assertRaises(AssertionError):
-      self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect
-    # TODO: test output is correct
+    self.assertEqual(shape, (1, 16, 13))
+    # concrete version for comparison
+    expected = x_full[:, :, :val].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
+    np.testing.assert_allclose(result[:, :, :13].numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)

 if __name__ == '__main__':
   unittest.main()
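
Aside: the expected shape in this test follows from plain integer arithmetic, so the edge case can be checked without tinygrad. A minimal sketch using the conv parameters above (illustration only, not part of the diff):

# output length for L=79, padding 3+3, kernel 12, dilation 1, stride 6 is ceildiv(L - 5, 6)
L = 79
numer = L + 2*3 - 1*(12 - 1)        # 74, i.e. L - 5
assert -(numer // -6) == 13         # old formula, numerically correct on plain ints
assert (numer + 6 - 1) // 6 == 13   # new formula, same value
assert (L + 1) // 6 - 1 == 12       # the mis-simplified symbolic form ((v+1)//6-1) is off by one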

View File

@@ -2,6 +2,7 @@ import ctypes, gzip, unittest, timeit, pickle
 from tinygrad import Variable
 from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
+from tinygrad.helpers import ceildiv
 from tinygrad.tensor import Tensor, get_shape
 import numpy as np
@@ -120,6 +121,25 @@ class TestRoundUp(unittest.TestCase):
     self.assertEqual(round_up(232, 24984), 24984)
     self.assertEqual(round_up(24984, 232), 25056)

+class TestCeilDiv(unittest.TestCase):
+  def test_int(self):
+    self.assertEqual(ceildiv(10, 3), 4)
+    self.assertEqual(ceildiv(9, 3), 3)
+    self.assertEqual(ceildiv(0, 5), 0)
+    self.assertEqual(ceildiv(1, 5), 1)
+  def test_symbolic(self):
+    # tests that ceildiv with UOp uses (num + amt - 1) // amt formula for non-negative num
+    v = Variable('v', 0, 100)
+    result = ceildiv(v, 6)
+    self.assertEqual(result.render(), "((v+5)//6)")
+  def test_symbolic_negative_offset(self):
+    # tests ceildiv(v-5, 6) which is used in conv2d output shape
+    # old implementation incorrectly simplified -(x//-y) to ((v+1)//6-1) for v-5
+    # new implementation uses (v-5+5)//6 = v//6 which is correct
+    v = Variable('v', 11, 100)
+    result = ceildiv(v - 5, 6)
+    self.assertEqual(result.render(), "(v//6)")
+
 class TestCount(unittest.TestCase):
   def test_count_basic(self):
     c = count(3)
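
Aside: test_int and test_symbolic rely on the fact that, for num >= 0 and amt > 0, the two ceildiv formulas agree. A quick brute-force sketch of that equivalence (illustration only):

# -(num // -amt) and (num + amt - 1) // amt compute the same ceiling division
# for non-negative num and positive amt, the only case the new symbolic branch handles
for num in range(0, 200):
  for amt in range(1, 20):
    assert -(num // -amt) == (num + amt - 1) // amt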

View File

@@ -43,7 +43,10 @@ def fully_flatten(l):
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def _is_balanced(s:str) -> bool: return (d := 0, all((d := d + (c == '(') - (c == ')')) >= 0 for c in s))[1] and d == 0
 def strip_parens(fst:str) -> str: return fst[1:-1] if fst[:1]=='(' and fst[-1:]==')' and _is_balanced(fst[1:-1]) else fst
-def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
+def ceildiv(num, amt):
+  # use (num + amt - 1) // amt when num is a UOp and non-negative to avoid C/Python division mismatch
+  if hasattr(num, 'vmin') and num.vmin >= 0 and (amt > 0 if isinstance(amt, int) else amt.vmin > 0): return (num + amt - 1) // amt
+  return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
 def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
 def round_down(num:int, amt:int) -> int: return -round_up(-num, amt)
 def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length()
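
Putting the pieces together, the reworked helper can be exercised as below; this is a usage sketch that simply restates what the new unit tests assert:

from tinygrad import Variable
from tinygrad.helpers import ceildiv

assert ceildiv(10, 3) == 4                     # plain ints keep the -(num // -amt) path
v = Variable('v', 11, 100)
assert ceildiv(v, 6).render() == "((v+5)//6)"  # symbolic path uses (num + amt - 1) // amt
assert ceildiv(v - 5, 6).render() == "(v//6)"  # the conv2d-shape case that used to mis-simplify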