mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Add isinstance check before gcd call in SumNode.__lt__ (#2450)
* Add isinstance check before gcd call * Delete blank lines * Fix unit test typo * Delete blank lines again --------- Co-authored-by: Paul Gustafson <paul.gustafson@theambrusgroup.com>
This commit is contained in:
@@ -394,6 +394,14 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
assert a < 3 and (a < 3).min == 0 and (a < 3).max == 1
|
||||
assert a > 3 and (a > 3).min == 0 and (a > 3).max == 1
|
||||
|
||||
def test_sumnode_mulnode_lt(self):
|
||||
a = Variable("a", 1, 2)
|
||||
b = Variable("b", 1, 2)
|
||||
c = Variable("c", 1, 2)
|
||||
x = SumNode([MulNode(a, b), c])
|
||||
assert isinstance((x < 3), Node) and (x < 3) == 0
|
||||
assert isinstance((x < 4), LtNode) and (x < 4).min == 0 and (x < 4).max == 1
|
||||
|
||||
def test_num_node_mul_node(self):
|
||||
a = Variable("a", 1, 5)
|
||||
b = NumNode(2) * a
|
||||
@@ -448,6 +456,5 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
c = b.substitute({a: NumNode(1)})
|
||||
assert c == NumNode(2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -290,7 +290,7 @@ class SumNode(RedNode):
|
||||
if muls:
|
||||
# NOTE: gcd in python 3.8 takes exactly 2 args
|
||||
mul_gcd = b
|
||||
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here
|
||||
for x in muls: mul_gcd = gcd(mul_gcd, x.b) if isinstance(x.b, int) else 1
|
||||
all_others = Variable.sum(others)
|
||||
if all_others.min >= 0 and all_others.max < mul_gcd:
|
||||
lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
||||
|
||||
Reference in New Issue
Block a user