mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
new ceildiv that fixed symbolic conv (#13944)
* new ceildiv that fixed symbolic conv
* smaller test case
@@ -288,17 +288,22 @@ class TestSymbolicOps(unittest.TestCase):
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)

   def test_conv2d_ceildiv_edge_case(self):
-    v = Variable('v', 11, 50_000)
-    val = 39601
-    x = Tensor.randn(1, 22, 50_000)[:, :, :v.bind(val)]
-    weight = Tensor.randn(256, 22, 12)
+    # tests symbolic ceildiv in conv2d output shape calculation
+    # val=79 triggers the edge case where old ceildiv simplifies incorrectly: old gives floor=12, correct ceildiv=13
+    v = Variable('v', 11, 100)
+    val = 79
+    x_full = Tensor.randn(1, 8, 100)
+    weight = Tensor.randn(16, 8, 12)
+
-    result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
+    # symbolic version
+    result = x_full[:, :, :v.bind(val)].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
     var_val = {v.expr: val}
     shape = tuple(sym_infer(s, var_val) for s in result.shape)
-    with self.assertRaises(AssertionError):
-      self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect
-    # TODO: test output is correct
+    self.assertEqual(shape, (1, 16, 13))
+
+    # concrete version for comparison
+    expected = x_full[:, :, :val].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
+    np.testing.assert_allclose(result[:, :, :13].numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)

 if __name__ == '__main__':
   unittest.main()
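A quick sanity check on the expected shape, separate from the diff above: assuming the standard convolution output-length formula with the test's numbers (kernel size 12, stride 6, padding 3, dilation 1, bound value val=79), the spatial length works out to 13, and with the width left symbolic the same expression is exactly ceildiv(v - 5, 6), the form exercised by test_symbolic_negative_offset below. A plain-Python sketch, not part of the diff:

import math
W, kernel, stride, pad, dilation = 79, 12, 6, 3, 1     # values used by the test
n = W + 2*pad - dilation*(kernel - 1)                   # 79 + 6 - 11 = 74, i.e. W - 5
out_len = math.ceil(n / stride)                         # ceil(74 / 6) = 13
assert out_len == 13                                    # matches (1, 16, 13); the 16 channels come from weight (16, 8, 12)
# with W kept symbolic as v, the same length is ceildiv(v - 5, 6)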
@@ -2,6 +2,7 @@ import ctypes, gzip, unittest, timeit, pickle
 from tinygrad import Variable
 from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
+from tinygrad.helpers import ceildiv
 from tinygrad.tensor import Tensor, get_shape
 import numpy as np

@@ -120,6 +121,25 @@ class TestRoundUp(unittest.TestCase):
     self.assertEqual(round_up(232, 24984), 24984)
     self.assertEqual(round_up(24984, 232), 25056)

+class TestCeilDiv(unittest.TestCase):
+  def test_int(self):
+    self.assertEqual(ceildiv(10, 3), 4)
+    self.assertEqual(ceildiv(9, 3), 3)
+    self.assertEqual(ceildiv(0, 5), 0)
+    self.assertEqual(ceildiv(1, 5), 1)
+  def test_symbolic(self):
+    # tests that ceildiv with UOp uses (num + amt - 1) // amt formula for non-negative num
+    v = Variable('v', 0, 100)
+    result = ceildiv(v, 6)
+    self.assertEqual(result.render(), "((v+5)//6)")
+  def test_symbolic_negative_offset(self):
+    # tests ceildiv(v-5, 6) which is used in conv2d output shape
+    # old implementation incorrectly simplified -(x//-y) to ((v+1)//6-1) for v-5
+    # new implementation uses (v-5+5)//6 = v//6 which is correct
+    v = Variable('v', 11, 100)
+    result = ceildiv(v - 5, 6)
+    self.assertEqual(result.render(), "(v//6)")
+
 class TestCount(unittest.TestCase):
   def test_count_basic(self):
     c = count(3)
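To make the off-by-one in test_symbolic_negative_offset concrete, here is a plain-integer check (ordinary ints, not tinygrad UOps) over the Variable's range [11, 100]. It compares the true ceiling of (v - 5) / 6 against the new rendering v//6 and against the ((v+1)//6)-1 form that the comment above attributes to the old simplification:

import math
for v in range(11, 101):
  correct = math.ceil((v - 5) / 6)
  assert v // 6 == correct                    # new rendering is exact for every v in the range
  if v % 6 != 5:
    assert (v + 1) // 6 - 1 == correct - 1    # old simplified form comes out one too small
# e.g. v = 79: correct ceildiv(v - 5, 6) is 13, the old form gives 12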
@@ -43,7 +43,10 @@ def fully_flatten(l):
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def _is_balanced(s:str) -> bool: return (d := 0, all((d := d + (c == '(') - (c == ')')) >= 0 for c in s))[1] and d == 0
 def strip_parens(fst:str) -> str: return fst[1:-1] if fst[:1]=='(' and fst[-1:]==')' and _is_balanced(fst[1:-1]) else fst
-def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
+def ceildiv(num, amt):
+  # use (num + amt - 1) // amt when num is a UOp and non-negative to avoid C/Python division mismatch
+  if hasattr(num, 'vmin') and num.vmin >= 0 and (amt > 0 if isinstance(amt, int) else amt.vmin > 0): return (num + amt - 1) // amt
+  return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
 def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
 def round_down(num:int, amt:int) -> int: return -round_up(-num, amt)
 def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length()
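A note on the vmin >= 0 guard, as I read the "C/Python division mismatch" comment: -(num // -amt) is exact under Python's floor division, but once the expression is rendered into C-style code where integer division truncates toward zero, it can come out one short, while (num + amt - 1) // amt stays correct as long as num + amt - 1 is non-negative, because floor and truncation agree there. A plain-Python sketch of that reading, using int(a / b) as a stand-in for C's truncating division (an assumption for illustration only):

def floor_div(a, b): return a // b       # Python semantics: floor
def trunc_div(a, b): return int(a / b)   # C-style semantics: truncate toward zero

num, amt = 74, 6                         # ceil(74 / 6) == 13
assert -floor_div(num, -amt) == 13       # old formula under floor division: correct
assert -trunc_div(num, -amt) == 12       # old formula under truncating division: one short
assert floor_div(num + amt - 1, amt) == trunc_div(num + amt - 1, amt) == 13
# the num.vmin >= 0 (and amt > 0) guard keeps num + amt - 1 non-negative, where both conventions agree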