mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Node.vars() returns a set and properly dedup (#2356)
* dedup RedNode.vars() * vars returns a set * fix more vars * unused import * update to_movement_ops * comment
This commit is contained in:
@@ -71,8 +71,8 @@ def st_equivalent(st1: ShapeTracker, st2: ShapeTracker):
|
||||
# always invalid
|
||||
if valid1 == 0 and valid2 == 0: return True
|
||||
|
||||
var1 = set(idx1.vars() + valid1.vars())
|
||||
var2 = set(idx2.vars() + valid2.vars())
|
||||
var1 = idx1.vars() | valid1.vars()
|
||||
var2 = idx2.vars() | valid2.vars()
|
||||
# Maybe there are cases that vars are different yet the sts are the same?
|
||||
if var1 != var2: return False
|
||||
|
||||
|
||||
@@ -267,20 +267,25 @@ class TestSymbolicVars(unittest.TestCase):
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert z.vars() == z.vars() == []
|
||||
assert a.vars() == a.vars() == [a]
|
||||
assert z.vars() == z.vars() == set()
|
||||
assert a.vars() == a.vars() == {a}
|
||||
m = MulNode(a, 3)
|
||||
assert m.vars() == [a]
|
||||
assert m.vars() == {a}
|
||||
s = SumNode([a, b, c])
|
||||
assert s.vars() == [a, b, c]
|
||||
assert s.vars() == {a, b, c}
|
||||
|
||||
def test_compound(self):
|
||||
a = Variable("a", 0, 10)
|
||||
b = Variable("b", 0, 10)
|
||||
c = Variable("c", 0, 10)
|
||||
assert (a + b * c).vars() == [a, b, c]
|
||||
assert (a % 3 + b // 5).vars() == [a, b]
|
||||
assert (a + b + c - a).vars() == [b, c]
|
||||
assert (a + b * c).vars() == {a, b, c}
|
||||
assert (a % 3 + b // 5).vars() == {a, b}
|
||||
assert (a + b + c - a).vars() == {b, c}
|
||||
|
||||
def test_dedup(self):
|
||||
a = Variable("a", 0, 10)
|
||||
assert (a * a).vars() == {a}
|
||||
assert (a//4 + a//6).vars() == {a}
|
||||
|
||||
class TestSymbolicMinMax(unittest.TestCase):
|
||||
def test_min_max_known(self):
|
||||
|
||||
@@ -162,7 +162,8 @@ def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tup
|
||||
if valid.min == 0 and isinstance(idxy, SumNode):
|
||||
nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
|
||||
val_dict: Dict[Node, Any] = {}
|
||||
idxy_flat_var = [(i, i.vars()[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]
|
||||
# TODO: is this correct? should it check there's only one variable from each component?
|
||||
idxy_flat_var = [(i, list(i.vars())[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]
|
||||
|
||||
for node in nodes:
|
||||
assert isinstance(node, LtNode)
|
||||
|
||||
@@ -80,7 +80,7 @@ def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root.
|
||||
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
|
||||
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
|
||||
|
||||
def vars_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], []))
|
||||
def vars_from_ast(ast:LazyOp) -> Set[Variable]: return functools.reduce(operator.or_, [x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set())
|
||||
|
||||
lazycache: WeakValueDictionary = WeakValueDictionary()
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
|
||||
|
||||
@@ -180,7 +180,7 @@ class BatchExecutor:
|
||||
class ASTRunner:
|
||||
def __init__(self, ast:Optional[LazyOp]):
|
||||
if ast is None:
|
||||
self.op_estimate, self.mem_estimate, self.vars = 0, 0, []
|
||||
self.op_estimate, self.mem_estimate, self.vars = 0, 0, set()
|
||||
else:
|
||||
info = get_lazyop_info(ast)
|
||||
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
from __future__ import annotations
|
||||
import functools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, cast
|
||||
from typing import Tuple, List, Optional, Dict, Set, cast
|
||||
from tinygrad.ops import MovementOps
|
||||
from tinygrad.helpers import prod, DEBUG, dedup, merge_dicts
|
||||
from tinygrad.helpers import prod, DEBUG, merge_dicts
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
@@ -79,7 +79,7 @@ class ShapeTracker:
|
||||
|
||||
def size(self): return 0 if (0 in self.shape) else self.expr_idxs()[0].max+1
|
||||
|
||||
def vars(self) -> List[Variable]: return dedup(functools.reduce(operator.add, [v.vars() for v in self.views], []))
|
||||
def vars(self) -> Set[Variable]: return functools.reduce(operator.or_, [v.vars() for v in self.views], set())
|
||||
|
||||
@property
|
||||
def var_vals(self) -> Dict[Variable, int]: return merge_dicts([dict([v.unbind()]) for v in self.vars()])
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools
|
||||
from math import gcd
|
||||
from itertools import product
|
||||
from tinygrad.helpers import partition
|
||||
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator
|
||||
from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Iterator, Set
|
||||
|
||||
# NOTE: Python has different behavior for negative mod and floor div than c
|
||||
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
|
||||
@@ -18,7 +18,7 @@ class Node:
|
||||
if ops is None: ops = render_python
|
||||
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
def vars(self): return []
|
||||
def vars(self) -> Set[Variable]: return set()
|
||||
|
||||
def expand_idx(self) -> VariableOrNum: return next((v for v in self.vars() if v.expr is None), NumNode(0))
|
||||
# expand a Node into List[Node] that enumerates the underlying Variables from min to max
|
||||
@@ -149,7 +149,7 @@ class Variable(Node):
|
||||
def unbind(self) -> Tuple[Variable, int]:
|
||||
assert self.val is not None, f"cannot unbind {self}"
|
||||
return Variable(self.expr, self.min, self.max), self.val
|
||||
def vars(self): return [self]
|
||||
def vars(self): return {self}
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return var_vals[self] if self in var_vals else self
|
||||
|
||||
class NumNode(Node):
|
||||
@@ -173,7 +173,7 @@ class OpNode(Node):
|
||||
def __init__(self, a:Node, b:Union[Node, int]):
|
||||
self.a, self.b = a, b
|
||||
self.min, self.max = self.get_bounds()
|
||||
def vars(self): return self.a.vars() + (self.b.vars() if isinstance(self.b, Node) else [])
|
||||
def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
|
||||
def get_bounds(self) -> Tuple[int, int]: raise NotImplementedError("must be implemented")
|
||||
|
||||
class LtNode(OpNode):
|
||||
@@ -221,7 +221,7 @@ class ModNode(OpNode):
|
||||
|
||||
class RedNode(Node):
|
||||
def __init__(self, nodes:List[Node]): self.nodes = nodes
|
||||
def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
|
||||
def vars(self) -> Set[Variable]: return functools.reduce(lambda l,x: l | x.vars(), self.nodes, set())
|
||||
|
||||
class SumNode(RedNode):
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
|
||||
@@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
import functools, operator
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, cast
|
||||
from tinygrad.helpers import prod, all_int, dedup
|
||||
from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, sint
|
||||
from tinygrad.helpers import prod, all_int
|
||||
from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, Set, sint
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def filter_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
@@ -30,9 +30,9 @@ class View:
|
||||
contiguous = offset == 0 and mask is None and all(s1 == s2 for s1,s2 in zip(strides, strides_for_shape(shape)))
|
||||
return View(shape, strides, offset, mask, contiguous)
|
||||
|
||||
def vars(self) -> List[Variable]:
|
||||
def vars(self) -> Set[Variable]:
|
||||
flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
|
||||
return dedup(functools.reduce(operator.add, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], []))
|
||||
return functools.reduce(operator.or_, [x.vars() for x in self.shape+self.strides+(self.offset,)+flatten_mask if isinstance(x, Node)], set())
|
||||
|
||||
def unbind(self) -> View:
|
||||
unbound_vars:Dict[VariableOrNum,Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
|
||||
|
||||
Reference in New Issue
Block a user