From d7d078c7f9301bdd52ffe002bc8635264a81fdc1 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 18 Nov 2023 17:44:52 -0500 Subject: [PATCH] Node.vars() returns a set and properly dedup (#2356) * dedup RedNode.vars() * vars returns a set * fix more vars * unused import * update to_movement_ops * comment --- extra/to_movement_ops.py | 4 ++-- test/unit/test_symbolic.py | 19 ++++++++++++------- tinygrad/features/image.py | 3 ++- tinygrad/lazy.py | 2 +- tinygrad/ops.py | 2 +- tinygrad/shape/shapetracker.py | 6 +++--- tinygrad/shape/symbolic.py | 10 +++++----- tinygrad/shape/view.py | 8 ++++---- 8 files changed, 30 insertions(+), 24 deletions(-) diff --git a/extra/to_movement_ops.py b/extra/to_movement_ops.py index 086b07c90d..69560179ac 100644 --- a/extra/to_movement_ops.py +++ b/extra/to_movement_ops.py @@ -71,8 +71,8 @@ def st_equivalent(st1: ShapeTracker, st2: ShapeTracker): # always invalid if valid1 == 0 and valid2 == 0: return True - var1 = set(idx1.vars() + valid1.vars()) - var2 = set(idx2.vars() + valid2.vars()) + var1 = idx1.vars() | valid1.vars() + var2 = idx2.vars() | valid2.vars() # Maybe there are cases that vars are different yet the sts are the same? if var1 != var2: return False diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index a452efc596..4d393f8a3a 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -267,20 +267,25 @@ class TestSymbolicVars(unittest.TestCase): a = Variable("a", 0, 10) b = Variable("b", 0, 10) c = Variable("c", 0, 10) - assert z.vars() == z.vars() == [] - assert a.vars() == a.vars() == [a] + assert z.vars() == z.vars() == set() + assert a.vars() == a.vars() == {a} m = MulNode(a, 3) - assert m.vars() == [a] + assert m.vars() == {a} s = SumNode([a, b, c]) - assert s.vars() == [a, b, c] + assert s.vars() == {a, b, c} def test_compound(self): a = Variable("a", 0, 10) b = Variable("b", 0, 10) c = Variable("c", 0, 10) - assert (a + b * c).vars() == [a, b, c] - assert (a % 3 + b // 5).vars() == [a, b] - assert (a + b + c - a).vars() == [b, c] + assert (a + b * c).vars() == {a, b, c} + assert (a % 3 + b // 5).vars() == {a, b} + assert (a + b + c - a).vars() == {b, c} + + def test_dedup(self): + a = Variable("a", 0, 10) + assert (a * a).vars() == {a} + assert (a//4 + a//6).vars() == {a} class TestSymbolicMinMax(unittest.TestCase): def test_min_max_known(self): diff --git a/tinygrad/features/image.py b/tinygrad/features/image.py index 6cf6444e33..ee8e3517e6 100644 --- a/tinygrad/features/image.py +++ b/tinygrad/features/image.py @@ -162,7 +162,8 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tup if valid.min == 0 and isinstance(idxy, SumNode): nodes = valid.nodes if isinstance(valid, AndNode) else [valid] val_dict: Dict[Node, Any] = {} - idxy_flat_var = [(i, i.vars()[0]) for i in idxy.flat_components if not isinstance(i, NumNode)] + # TODO: is this correct? should it check there's only one variable from each component? + idxy_flat_var = [(i, list(i.vars())[0]) for i in idxy.flat_components if not isinstance(i, NumNode)] for node in nodes: assert isinstance(node, LtNode) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index e116e9f4b6..e8cb0d164c 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -80,7 +80,7 @@ def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root. def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x) -def vars_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], [])) +def vars_from_ast(ast:LazyOp) -> Set[Variable]: return functools.reduce(operator.or_, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()) lazycache: WeakValueDictionary = WeakValueDictionary() def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 109c8a1d58..6221aa93fa 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -180,7 +180,7 @@ class BatchExecutor: class ASTRunner: def __init__(self, ast:Optional[LazyOp]): if ast is None: - self.op_estimate, self.mem_estimate, self.vars = 0, 0, [] + self.op_estimate, self.mem_estimate, self.vars = 0, 0, set() else: info = get_lazyop_info(ast) self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 31433858f2..31ac85ae13 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -2,9 +2,9 @@ from __future__ import annotations import functools, operator from dataclasses import dataclass -from typing import Tuple, List, Optional, Dict, cast +from typing import Tuple, List, Optional, Dict, Set, cast from tinygrad.ops import MovementOps -from tinygrad.helpers import prod, DEBUG, dedup, merge_dicts +from tinygrad.helpers import prod, DEBUG, merge_dicts from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint from tinygrad.shape.view import View @@ -79,7 +79,7 @@ class ShapeTracker: def size(self): return 0 if (0 in self.shape) else self.expr_idxs()[0].max+1 - def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], [])) + def vars(self) -> Set[Variable]: return functools.reduce(operator.or_, [v.vars() for v in self.views], set()) @property def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()]) diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index dfdadbf543..a7acca9d18 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -3,7 +3,7 @@ import functools from math import gcd from itertools import product from tinygrad.helpers import partition -from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator +from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator, Set # NOTE: Python has different behavior for negative mod and floor div than c # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod @@ -18,7 +18,7 @@ class Node: if ops is None: ops = render_python assert self.__class__ in (Variable, NumNode) or self.min != self.max return ops[type(self)](self, ops, ctx) - def vars(self): return [] + def vars(self) -> Set[Variable]: return set() def expand_idx(self) -> VariableOrNum: return next((v for v in self.vars() if v.expr is None), NumNode(0)) # expand a Node into List[Node] that enumerates the underlying Variables from min to max @@ -149,7 +149,7 @@ class Variable(Node): def unbind(self) -> Tuple[Variable, int]: assert self.val is not None, f"cannot unbind {self}" return Variable(self.expr, self.min, self.max), self.val - def vars(self): return [self] + def vars(self): return {self} def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self class NumNode(Node): @@ -173,7 +173,7 @@ class OpNode(Node): def __init__(self, a:Node, b:Union[Node, int]): self.a, self.b = a, b self.min, self.max = self.get_bounds() - def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else []) + def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set()) def get_bounds(self) -> Tuple[int, int]: raise NotImplementedError("must be implemented") class LtNode(OpNode): @@ -221,7 +221,7 @@ class ModNode(OpNode): class RedNode(Node): def __init__(self, nodes:List[Node]): self.nodes = nodes - def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, []) + def vars(self) -> Set[Variable]: return functools.reduce(lambda l,x: l | x.vars(), self.nodes, set()) class SumNode(RedNode): @functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 4d4b793946..396ec7f8b5 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -2,8 +2,8 @@ from __future__ import annotations import functools, operator from dataclasses import dataclass from typing import Tuple, List, Optional, Dict, cast -from tinygrad.helpers import prod, all_int, dedup -from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, sint +from tinygrad.helpers import prod, all_int +from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, Set, sint @functools.lru_cache(maxsize=None) def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]: @@ -30,9 +30,9 @@ class View: contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape))) return View(shape, strides, offset, mask, contiguous) - def vars(self) -> List[Variable]: + def vars(self) -> Set[Variable]: flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple() - return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], [])) + return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set()) def unbind(self) -> View: unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}