mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove NumNode.int() (#1876)
This commit is contained in:
@@ -358,17 +358,18 @@ class TestSymbolicSymbolicOps(unittest.TestCase):
|
||||
a = Variable("a", 1, 5)
|
||||
b = Variable("b", 6, 9)
|
||||
c = Variable("c", 1, 10)
|
||||
d = Variable("d", 5, 10)
|
||||
# if the value is always the same, it folds to num
|
||||
assert (a < b) == 1
|
||||
# if it remains as a LtNode, bool is always true and we need to test against min to test if it always evals to True
|
||||
assert (a < c).__class__ is LtNode and (a < c).min == 0 and (a < c).max == 1
|
||||
assert (b < a) == 0
|
||||
assert (d < a) == 0
|
||||
# if it remains as a LtNode, bool is always true and (min, max) == (0, 1)
|
||||
assert isinstance((a < c), LtNode) and (a < c).min == 0 and (a < c).max == 1
|
||||
assert a < c
|
||||
assert not (a < c).min
|
||||
assert (a > c).__class__ is LtNode and (a > c).min == 0 and (a > c).max == 1
|
||||
assert not (a > c).min
|
||||
assert isinstance((a > c), LtNode) and (a > c).min == 0 and (a > c).max == 1
|
||||
# same when comparing with a constant
|
||||
assert a < 3
|
||||
assert a > 3
|
||||
assert a < 3 and (a < 3).min == 0 and (a < 3).max == 1
|
||||
assert a > 3 and (a > 3).min == 0 and (a > 3).max == 1
|
||||
|
||||
def test_num_node_mul_node(self):
|
||||
a = Variable("a", 1, 5)
|
||||
|
||||
@@ -156,7 +156,6 @@ class NumNode(Node):
|
||||
assert isinstance(num, int), f"{num} is not an int"
|
||||
self.b:int = num
|
||||
self.min, self.max = num, num
|
||||
def __int__(self): return self.b
|
||||
def __eq__(self, other): return self.b == other
|
||||
def __hash__(self): return self.hash # needed with __eq__ override
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self
|
||||
@@ -178,8 +177,9 @@ class LtNode(OpNode):
|
||||
def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
|
||||
def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
|
||||
def get_bounds(self) -> Tuple[int, int]:
|
||||
if isinstance(self.b, int): return int(self.a.max < self.b), int(self.a.min < self.b)
|
||||
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min > self.b.max else (0, 1)
|
||||
if isinstance(self.b, int):
|
||||
return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
|
||||
return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return self.a.substitute(var_vals) < (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
|
||||
|
||||
class MulNode(OpNode):
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Seque
|
||||
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.ops import Device, LoadOps
|
||||
from tinygrad.shape.symbolic import NumNode, sint, all_int
|
||||
from tinygrad.shape.symbolic import sint, all_int
|
||||
|
||||
# An instantiation of the Function is the Context
|
||||
class Function:
|
||||
@@ -303,9 +303,8 @@ class Tensor:
|
||||
if isinstance(s, int):
|
||||
dim_collapsed += 1
|
||||
else:
|
||||
# TODO: replace NumNode with int in shape
|
||||
assert isinstance(dim_shape, (int, NumNode)), f"does not support symbolic shape {dim_shape}"
|
||||
final_shape.append(int(dim_shape))
|
||||
assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}"
|
||||
final_shape.append(dim_shape)
|
||||
if isinstance(s, Tensor):
|
||||
tensors.append(s)
|
||||
dim.append(i-dim_collapsed)
|
||||
|
||||
Reference in New Issue
Block a user