From 7a1d96fd7670316fb55bbe8d079d44e2d58fa36b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 3 Mar 2023 16:48:14 -0800 Subject: [PATCH] No negative (#632) * behavior is correct without VALIDHACKS * simple div and mod * fix tests * no negative variables * alt form is correct * still correct * bug in mulnode * at least validhacks works now * cleanups * test validhacks, and to_image_idx * cache compare key * tests and __neg__ --- .github/workflows/test.yml | 3 +- test/unit/test_shapetracker.py | 24 ++--- test/unit/test_symbolic.py | 87 +++++++++++++----- tinygrad/codegen/gpu.py | 34 +++++-- tinygrad/shape/symbolic.py | 156 +++++++++++++++++---------------- 5 files changed, 181 insertions(+), 123 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2039d51b8c..40848dbe77 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -162,7 +162,8 @@ jobs: - name: Test openpilot model run: | ALLOWED_KERNEL_COUNT=197 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py - UNSAFE_FLOAT4=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py + DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py + VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py # disabled, this test is flaky testdocker: diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index a96f5d11de..3dd9babf7a 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import unittest import numpy as np -from tinygrad.helpers import prod +from tinygrad.helpers import prod, all_same from tinygrad.shape import ShapeTracker, View, ZeroView, merge_views from tinygrad.codegen.gpu import to_image_idx @@ -66,14 +66,7 @@ class CheckingShapeTracker: class TestImageShapeTracker(unittest.TestCase): def test_image(self): base_shape = (64, 1024, 4) - - """ - st = ShapeTracker(shape=(8, 64, 128, 3), views=[ - 
View((1, 66, 130, 32, 1, 1), (0, 4096, 32, 1, 0, 0), -4128), - ZeroView((1, 64, 128, 32, 1, 1), ((0, 1), (-1, 65), (-1, 129), (0, 32), (0, 1), (0, 1))), - View((8, 64, 128, 3), (4, 4160, 32, 4160), 0)]) - offsets = [0,32,64] - """ + print(base_shape) new_view = merge_views( View((1, 66, 130, 32, 1, 1), (0, 4096, 32, 1, 0, 0), -4128), @@ -88,19 +81,18 @@ class TestImageShapeTracker(unittest.TestCase): offsets = [0,32,64,96] print(st.shape) + idys = [] for o in offsets: print("offset:", o) idxy, valid = st.expr_idxs(o) print("idxy:", idxy.render()) print("valids:", [x.render() for x in valid.nodes]) - out = to_image_idx(base_shape, idxy, True) - print(out) - #idx = (idxy//4)%base_shape[1] - #idy = (idxy//(4*base_shape[1]))%base_shape[0] - #idx, idy = [x.a if isinstance(x, ModNode) and x.a.max < x.b*2 else x for x in (idx, idy)] + idx, idy = to_image_idx(base_shape, idxy, valid, True) + idys.append(idy) + print(base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy) - #print("idx:", idx.render()) - #print("idy:", idy.render()) + # y index shouldn't be changing + assert all_same(idys) class TestSimplifyingShapeTracker(unittest.TestCase): def setUp(self): diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index 9392b38ae5..41ec6bf46e 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import unittest -from tinygrad.shape.symbolic import Variable, divn, modn +from tinygrad.shape.symbolic import Variable, NumNode, Node class TestSymbolic(unittest.TestCase): def helper_test_variable(self, v, n, m, s): @@ -23,6 +23,48 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 3, 8)<4, 0, 1, "(a<4)") self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0") self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0") + + def test_div_becomes_num(self): + assert isinstance(Variable("a", 2, 3)//2, NumNode) + + def test_var_becomes_num(self): + assert isinstance(Variable("a", 
2, 2), NumNode) + + def test_equality(self): + idx1 = Variable("idx1", 0, 3) + idx2 = Variable("idx2", 0, 3) + assert idx1 == idx1 + assert idx1 != idx2 + assert idx1*4 == idx1*4 + assert idx1*4 != idx1*3 + assert idx1*4 != idx1+4 + assert idx1*4 != idx2*4 + assert idx1+idx2 == idx1+idx2 + assert idx1+idx2 == idx2+idx1 + assert idx1+idx2 != idx2 + + def test_factorize(self): + a = Variable("a", 0, 8) + self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)") + + def test_factorize_no_mul(self): + a = Variable("a", 0, 8) + self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)") + + def test_neg(self): + self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)") + + def test_add_1(self): + self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(1+a)") + + def test_add_num_1(self): + self.helper_test_variable(Variable("a", 0, 8)+Variable.num(1), 1, 9, "(1+a)") + + def test_sub_1(self): + self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(-1+a)") + + def test_sub_num_1(self): + self.helper_test_variable(Variable("a", 0, 8)-Variable.num(1), -1, 7, "(-1+a)") def test_mul_0(self): self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0") @@ -45,6 +87,9 @@ class TestSymbolic(unittest.TestCase): def test_div_min_max(self): self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)") + def test_div_neg_min_max(self): + self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)") + def test_sum_div_min_max(self): self.helper_test_variable(Variable.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)") @@ -57,9 +102,9 @@ class TestSymbolic(unittest.TestCase): def test_sum_div_no_factor(self): self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)") - @unittest.skip("mod max is wrong") def test_mod_factor(self): - self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b*50)%100)") + # NOTE: even though the mod max is 
50, it can't know this without knowing about the mul + self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)") def test_sum_div_const(self): self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable.num(3)]) // 4, 0, 7, "a") @@ -73,6 +118,9 @@ class TestSymbolic(unittest.TestCase): def test_mul_mul(self): self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)") + def test_div_div(self): + self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)") + def test_distribute_mul(self): self.helper_test_variable(Variable.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))") @@ -86,11 +134,12 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a") def test_big_mod(self): - self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)") - self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)") - self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)") + # NOTE: we no longer support negative variables + #self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)") + #self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)") + #self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)") self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)") - self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)") + #self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)") def test_gt_remove(self): self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "0") @@ -107,16 +156,14 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Variable.ands([Variable.num(1), Variable("a", 0, 1)]), 0, 1, "a") def test_mod_factor_negative(self): - # this is technically wrong, if b is 0 the output will be negative - self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 
10)*28]) % 28, -1, 9, "((-1+a)%28)") - self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -1, 27, "((-1+a)%28)") + self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") + self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)") def test_sum_combine_num(self): self.helper_test_variable(Variable.sum([Variable.num(29), Variable("a", 0, 10), Variable.num(-23)]), 6, 16, "(6+a)") def test_div_factor(self): - # TODO: this isn't right - self.helper_test_variable(Variable.sum([Variable.num(-44), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)") + self.helper_test_variable(Variable.sum([Variable.num(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)") def test_mul_div(self): self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a") @@ -132,7 +179,7 @@ class TestSymbolic(unittest.TestCase): class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): - # TODO: why are the negative tests broken? + # TODO: why are the negative tests broken? 
(even if we did support negative variables) #MIN, MAX = -10, 10 MIN, MAX = 0, 10 # one number @@ -150,15 +197,15 @@ class TestSymbolicNumeric(unittest.TestCase): self.assertLessEqual(v.min, min(values)) self.assertGreaterEqual(v.max, max(values)) - def test_mod_4(self): self.helper_test_numeric(lambda x: modn(x, 4)) - def test_div_4(self): self.helper_test_numeric(lambda x: divn(x, 4)) - def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: divn(x+1, 2)) - def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: modn(x+1, 2)) + def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4)) + def test_div_4(self): self.helper_test_numeric(lambda x: (x//4)) + def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: (x+1)//2) + def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: (x+1)%2) def test_times_2(self): self.helper_test_numeric(lambda x: x*2) def test_times_2_plus_3(self): self.helper_test_numeric(lambda x: x*2 + 3) - def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: modn(x*2 + 3, 4)) - def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: divn(x*2 + 3, 4)) - def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: modn(divn(x*2 + 3, 4), 4)) + def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)%4) + def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4) + def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/gpu.py b/tinygrad/codegen/gpu.py index fc1251ef7f..1373d3c9d1 100644 --- a/tinygrad/codegen/gpu.py +++ b/tinygrad/codegen/gpu.py @@ -3,9 +3,9 @@ from collections import defaultdict from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner from tinygrad.codegen.ast import ASTKernel, Token, Types 
-from tinygrad.shape.symbolic import Node, ModNode, DivNode, render_python +from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, Variable, render_python from tinygrad.shape import ShapeTracker -from tinygrad.helpers import getenv, DEBUG, prod +from tinygrad.helpers import getenv, DEBUG, prod, partition # div is different in cl than python render_cl = render_python.copy() @@ -25,11 +25,25 @@ class GPULanguage(NamedTuple): extra_args : List[str] = [] float4 : Optional[str] = None -def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, validhacks=False): - idx = (idxy//4)%base_shape[1] - idy = (idxy//(4*base_shape[1]))%base_shape[0] - if validhacks: idx, idy = [x.a if isinstance(x, ModNode) and x.a.max < x.b*2 else x for x in (idx, idy)] - return f"(int2)({idx.render(render_cl)}, {idy.render(render_cl)})" +def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]: + idy = (idxy//(4*base_shape[1])) + if validhacks and valid.min == 0: + idx = (idxy//4) + (idy*-base_shape[1]) + # find the ones in idx that didn't factorize and remove them (TODO: this is not universal) + if isinstance(idx, SumNode): + unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1]) + assert len(unfactored) <= 1 + idx = Variable.sum(idx_nodes) + unfactored = (Variable.sum(unfactored) // base_shape[1]) + idy += unfactored + # ugh really... 
+ if idx.min >= base_shape[1]//2: + idx -= base_shape[1] + idy += 1 + else: + idx = (idxy//4)%base_shape[1] + #print(base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy) + return idx, idy class GPUCodegen(ASTKernel): lang : GPULanguage = GPULanguage() @@ -66,7 +80,8 @@ class GPUCodegen(ASTKernel): v = Token(f"{self.lang.float4}({','.join([to_store[o+j].tok for j in range(4)])})", Types.FLOAT4) if hasattr(self.bufs[buf_index]._buf, "IMAGE"): assert v.typ == Types.FLOAT4, "Image requires upcasting to FLOAT4" - self.kernel.append(f"write_imagef(data{buf_index}, {to_image_idx(self.bufs[buf_index]._base_shape, idxy)}, {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n") + idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid) + self.kernel.append(f"write_imagef(data{buf_index}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n") elif v.typ == Types.FLOAT4: self.kernel.append(f"(({self.lang.buffer_prefix}float4*)data{buf_index})[{(idxy//4).render(render_cl)}] = {v.tok};\n") else: @@ -97,7 +112,8 @@ class GPUCodegen(ASTKernel): ldr = const elif hasattr(self.bufs[buf_index]._buf, "IMAGE"): assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}" - ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {to_image_idx(self.bufs[buf_index]._base_shape, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4) + idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid, VALIDHACKS) + ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4) elif should_upcast and can_merge: ldr = Token(f"(({self.lang.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4) else: diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 
9464cf01f6..85ae63b24e 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -1,48 +1,61 @@ from __future__ import annotations -import math -from typing import List, Dict, Callable, Type +import math, itertools, functools +from typing import List, Dict, Callable, Type, Union from tinygrad.helpers import partition, all_same -# python has different behavior for negative mod and div than c -def divn(x, a): return x//a if isinstance(x, Node) else int(x/a) -def modn(x, a): return x%a if isinstance(x, Node) else (-((-x)%a) if x < 0 else x%a) +# NOTE: Python has different behavior for negative mod and floor div than c +# symbolic matches the Python behavior, but the code it outputs is agnostic, and will never have negative numbers in div or mod + +def create_node(typ:Type[Node], *args): + ret = typ(*args) + assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {typ} {args}" + if ret.min == ret.max: return NumNode(ret.min) + return ret class Node: b: int min: int max: int - def render(self, ops=None, ctx=None): + def render(self, ops=None, ctx=None) -> str: if ops is None: ops = render_python - if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops, ctx) + assert isinstance(self, NumNode) or self.min != self.max return ops[type(self)](self, ops, ctx) - def __add__(self, b:int): return Variable.sum([self, Variable.num(b)]) if b != 0 else self - def __sub__(self, b:int): return self+-b - def __ge__(self, b:int): return GeNode(self, b) - def __lt__(self, b:int): return LtNode(self, b) + @functools.cached_property + def key(self) -> str: return self.render() + def __repr__(self): return "<"+self.key+">" + def __eq__(self, other:object) -> bool: + if not isinstance(other, Node): return NotImplemented + return self.key == other.key + def __neg__(self): return self*-1 + def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)]) + def 
__sub__(self, b:Union[Node, int]): return self+-b + def __ge__(self, b:int): return create_node(GeNode, self, b) + def __lt__(self, b:int): return create_node(LtNode, self, b) def __mul__(self, b:int): if b == 0: return NumNode(0) elif b == 1: return self - if isinstance(self, MulNode): return MulNode(self.a, self.b*b) - # distribute mul into sum - if isinstance(self, SumNode): return Variable.sum([x*b for x in self.nodes]) - return MulNode(self, b) + if isinstance(self, MulNode): return self.a*(self.b*b) # two muls is one mul + if isinstance(self, SumNode): return Variable.sum([x*b for x in self.nodes]) # distribute mul into sum + return create_node(MulNode, self, b) # *** complex ops *** def __floordiv__(self, b:int): assert b != 0 + if b < 0: return (self//-b)*-1 if b == 1: return self - if isinstance(self, MulNode) and modn(self.b, b) == 0: return self.a*divn(self.b, b) - if isinstance(self, MulNode) and modn(b, self.b) == 0: return self.a//divn(b, self.b) + if isinstance(self, DivNode): return self.a//(self.b*b) # two divs is one div + if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b) + if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b) if isinstance(self, SumNode): factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0) nofactor = [] # ugh, i doubt this is universally right for x in tmp_nofactor: if isinstance(x, NumNode): - if modn(x.b, b) != x.b: - factors.append(Variable.num(x.b - modn(x.b, b))) # python does floor division - nofactor.append(Variable.num(modn(x.b, b))) + if (x.b%b) != x.b: + factors.append(Variable.num(x.b - (x.b%b))) # python does floor division + nofactor.append(Variable.num(x.b%b)) else: nofactor.append(x) gcd = [math.gcd(x.b, b) if isinstance(x, (MulNode, NumNode)) else None for x in nofactor] @@ -58,24 +71,28 @@ class Node: for m in muls: if m > 1 and b%m == 0: return (self//m)//(b//m) - return DivNode(self, b) + if self.min < 0: + 
offset = self.min//b + return (self+offset*b)//b - offset + return create_node(DivNode, self, b) def __mod__(self, b:int): + assert b > 0 if b == 1: return NumNode(0) if isinstance(self, SumNode): new_nodes = [] for x in self.nodes: - if isinstance(x, NumNode): new_nodes.append(Variable.num(modn(x.b, b))) - elif isinstance(x, MulNode): new_nodes.append(x.a * modn(x.b, b)) + if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b)) + elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b)) else: new_nodes.append(x) a = Variable.sum(new_nodes) elif isinstance(self, MulNode): - a = self.a * modn(self.b, b) + a = self.a * (self.b%b) else: a = self if a.min >= 0 and a.max < b: return a - if a.min == a.max: return Variable.num(modn(a.min, b)) - return ModNode(a, b) + if a.min < 0: return (a + ((a.min//b)*b)) % b + return create_node(ModNode, a, b) @staticmethod def num(num:int) -> Node: return NumNode(num) @@ -90,28 +107,35 @@ class Node: # combine any numbers inside a sum nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode)) - num_sum = sum([x.b for x in num_nodes]) - # TODO: these can't be merged due to image indexing. it's not clear which idx to group the offset with - if num_sum >= 0: nodes.append(NumNode(num_sum)) - else: - lte_0, rest = partition(num_nodes, lambda x: x.b <= 0) - nodes += [NumNode(x.b) for x in sorted(lte_0, key=lambda x:x.b) if x.b != 0] - if len(rest): nodes += [NumNode(sum([x.b for x in rest]))] + nodes.append(NumNode(sum([x.b for x in num_nodes]))) + + # combine any MulNodes that factorize (big hack sticking the MulNode(x, 1) on things) + nodes, mul_nodes = partition(nodes, lambda x: not isinstance(x, MulNode)) + mul_nodes += [MulNode(x, 1) for x in nodes] + mul_nodes = sorted(mul_nodes, key=lambda x: x.a.render()) # group by equality (ugh, uses render!) 
+ new_nodes = [k * sum(x.b for x in g) for k, g in itertools.groupby(mul_nodes, key=lambda x: x.a)] + nodes = [x if not isinstance(x, MulNode) or x.b != 1 else x.a for x in new_nodes] # filter 0s nodes = [x for x in nodes if x.min != 0 or x.max != 0] - return SumNode(nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(0)) + return create_node(SumNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(0)) @staticmethod def ands(nodes:List[Node]) -> Node: if any((x.min == 0 and x.max == 0) for x in nodes): return NumNode(0) + # filter 1s nodes = [x for x in nodes if x.min != x.max] - return AndNode(nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1)) + return create_node(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1)) # 4 basic node types class Variable(Node): + def __new__(cls, expr:str, nmin:int, nmax:int): + assert nmin >= 0 and nmin <= nmax + if nmin == nmax: return NumNode(nmin) + return super().__new__(cls) + def __init__(self, expr:str, nmin:int, nmax:int): self.expr, self.min, self.max = expr, nmin, nmax @@ -135,50 +159,28 @@ class RedNode(Node): class GeNode(OpNode): minmax = staticmethod(lambda a,b: (int(a.min >= b), int(a.max >= b))) class LtNode(OpNode): minmax = staticmethod(lambda a,b: (int(a.max < b), int(a.min < b))) -class MulNode(OpNode): minmax = staticmethod(lambda a,b: (a.min*b, a.max*b)) -class DivNode(OpNode): minmax = staticmethod(lambda a,b: (divn(a.min, b), divn(a.max, b))) - -# given a number in the range [amin, amax] (inclusive) -# what are the min and max of that number after modding it by b? 
- -# aka a fast version of: -#values = [modn(rv, b) for rv in range(amin, amax+1)] -#return min(values), max(values) +class MulNode(OpNode): minmax = staticmethod(lambda a,b: (a.min*b, a.max*b) if b >= 0 else (a.max*b, a.min*b)) +class DivNode(OpNode): + @staticmethod + def minmax(a, b): + assert a.min >= 0 + return a.min//b, a.max//b # you have 3 included ranges -# range 1 from min1 -> max1 (smaller than a mod) -# range 2 from max1 -> min2 -# range 3 from min2 -> max2 (smaller than a mod) - -def modrange_negative(amin, amax, b): - assert amin<0 and amax<0 - min1, max1 = amin, math.ceil(amin/b)*b - min2, max2 = math.floor(amax/b)*b, amax - if max1 > min2: return (modn(min1, b), modn(max2, b)) # range 2 doesn't exist, min1 -> max2 is smaller than a mod - if max1 < min2: return (-b+1, 0) # range 2 is the full distance - if min2 == max2: return (modn(min1, b), 0) # range 1 is the only valid - return (-b+1, 0) # range 1 and 3 are valid - -def modrange_positive(amin, amax, b): - assert amin>=0 and amax>=0 - min1, max1 = amin, math.ceil(amin/b)*b - min2, max2 = math.floor(amax/b)*b, amax - if max1 > min2: return (modn(min1, b), modn(max2, b)) # range 2 doesn't exist, min1 -> max2 is smaller than a mod - if max1 < min2: return (0, b-1) # range 2 is the full distance - if min1 == max1: return (0, modn(max2, b)) # range 3 is the only valid - return (0, b-1) # range 1 and 3 are valid - -def modrange(amin, amax, b): - if amin < 0 and amax < 0: - return modrange_negative(amin, amax, b) - if amin >= 0 and amax >= 0: - return modrange_positive(amin, amax, b) - if amin < 0 and amax >= 0: - min1, max1 = modrange_negative(amin, -1, b) - min2, max2 = modrange_positive(0, amax, b) - return min(min1, min2), max(max1, max2) - -class ModNode(OpNode): minmax = staticmethod(lambda a,b: modrange(a.min, a.max, b)) +# range 1 from a.min -> max1 (smaller than a mod) +# range 2 from max1 -> min2 +# range 3 from min2 -> a.max (smaller than a mod) +class ModNode(OpNode): + @staticmethod + def 
minmax(a, b): + assert a.min >= 0 + #values = [x%b for x in range(a.min, a.max+1)] + #return min(values), max(values) + max1, min2 = math.ceil(a.min/b)*b, math.floor(a.max/b)*b + if max1 < min2: return (0, b-1) # range 2 is the full distance + if max1 > min2: return (a.min%b, a.max%b) # range 2 doesn't exist, a.min -> a.max is smaller than a mod + if a.min == max1: return (0, a.max%b) # range 3 is the only valid + return (0, b-1) # range 1 and 3 are valid # reduce nodes