mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
safe changes for new symbolic [pr] (#6864)
This commit is contained in:
@@ -140,7 +140,7 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
|
||||
|
||||
try:
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals: Dict[Variable, int] = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
|
||||
var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[lin.opts.device]
|
||||
while not exiting:
|
||||
@@ -198,7 +198,7 @@ def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_
|
||||
assert dev.compiler is not None
|
||||
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals: Dict[Variable, int] = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
|
||||
var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
|
||||
p = lin.to_program()
|
||||
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
|
||||
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
|
||||
|
||||
@@ -3,7 +3,7 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from tinygrad.helpers import argsort
|
||||
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
|
||||
from tinygrad.ops import ReduceOps
|
||||
from tinygrad.ops import ReduceOps, resolve
|
||||
from tinygrad.tensor import Function
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.shape.symbolic import sint
|
||||
@@ -170,7 +170,7 @@ class Max(Function):
|
||||
# NOTE: this is sum in reverse
|
||||
class Expand(Function):
|
||||
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
|
||||
self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
|
||||
self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
|
||||
return x.expand(shape)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
|
||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
from typing import Union, Optional, Any, Tuple, List, get_args
|
||||
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
|
||||
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU, identity_element, MathTrait
|
||||
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU
|
||||
from tinygrad.ops import identity_element, MathTrait, resolve
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.device import Buffer
|
||||
@@ -167,7 +168,7 @@ class LazyBuffer(MathTrait):
|
||||
|
||||
def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
|
||||
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
|
||||
axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
|
||||
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
|
||||
if len(axis) == 0: return self
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, op, axis, (self,))
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.E
|
||||
def resolve(x, default:bool=True):
  """Return the truthiness of `x`, or `default` when it cannot be decided.

  `bool(x)` is attempted first; if that raises `ValueError` the
  caller-supplied `default` is returned instead. Any other exception
  propagates unchanged.
  """
  try:
    return bool(x)
  except ValueError:
    return default
|
||||
def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.max)
|
||||
def smax(lst):
  """Return the element of `lst` that ranks highest.

  Plain ints are ranked by their own value; every other element is ranked
  by its `vmax` attribute. The element itself is returned, not its rank.
  """
  def upper_bound(item): return item if isinstance(item, int) else item.vmax
  return max(lst, key=upper_bound)
|
||||
|
||||
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
|
||||
class UOp(MathTrait):
|
||||
@@ -269,7 +269,7 @@ class UOp(MathTrait):
|
||||
def vars(self) -> Set[UOp]:
  # every uop in self.sparents whose op is UOps.DEFINE_VAR
  return {u for u in self.sparents if u.op is UOps.DEFINE_VAR}
|
||||
def variables(self) -> List[Variable]:
|
||||
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.expr)
|
||||
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.arg)
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is UOps.CONST: return self.arg
|
||||
|
||||
@@ -52,7 +52,7 @@ class Program:
|
||||
special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
|
||||
assert special_size is not None
|
||||
special_size[int(u.arg[0][-1])] = u.arg[1]
|
||||
self.vars = sorted(self.vars, key=lambda v: v.expr)
|
||||
self.vars = sorted(self.vars, key=lambda v: v.arg)
|
||||
self.outs = sorted(dedup(self.outs))
|
||||
self._ran_post_init = True
|
||||
|
||||
|
||||
@@ -11,6 +11,13 @@ class Node:
|
||||
b: Union[Node, int]
|
||||
min: int
|
||||
max: sint
|
||||
|
||||
# helpers for the migration
|
||||
@property
def vmin(self):
  """Read-only alias for `self.min` (a helper for the migration)."""
  return self.min
|
||||
@property
def vmax(self):
  """Read-only alias for `self.max` (a helper for the migration)."""
  return self.max
|
||||
|
||||
def render(self, ops=None, ctx=None) -> Any:
|
||||
if ops is None: ops = render_python
|
||||
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
||||
@@ -119,6 +126,8 @@ class Variable(Node):
|
||||
return super().__new__(cls)
|
||||
|
||||
def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
|
||||
@property
def arg(self):
  """Read-only alias for `self.expr`."""
  return self.expr
|
||||
|
||||
def __init__(self, expr:str, nmin:int, nmax:sint):
|
||||
self.expr, self.min, self.max = expr, nmin, nmax
|
||||
|
||||
@@ -93,7 +93,7 @@ class View:
|
||||
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
|
||||
def size(self) -> int:
|
||||
# NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
|
||||
ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
|
||||
ret = prod([x.vmax if isinstance(x, Node) else x for x in self.shape])
|
||||
assert isinstance(ret, int), f"{ret=} is not int"
|
||||
return ret
|
||||
|
||||
@@ -174,7 +174,7 @@ class View:
|
||||
# Try to project vm2's mask on to vm1.
|
||||
newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
|
||||
for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
|
||||
if not (t.min < b or t.max >= e): continue
|
||||
if not (t.vmin < b or t.vmax >= e): continue
|
||||
if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
|
||||
bad = True
|
||||
continue
|
||||
|
||||
@@ -10,7 +10,7 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.ops import MetaOps, truncate, smax
|
||||
from tinygrad.ops import MetaOps, truncate, smax, resolve
|
||||
from tinygrad.device import Device, Buffer, BufferOptions
|
||||
from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
|
||||
from tinygrad.engine.realize import run_schedule, memory_planner
|
||||
@@ -1592,7 +1592,7 @@ class Tensor:
|
||||
"""
|
||||
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
|
||||
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
|
||||
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
|
||||
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
|
||||
|
||||
def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
||||
"""
|
||||
@@ -1617,7 +1617,7 @@ class Tensor:
|
||||
```
|
||||
"""
|
||||
squares = (self - self.mean(axis=axis, keepdim=True)).square()
|
||||
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
|
||||
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
|
||||
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
|
||||
|
||||
def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
|
||||
|
||||
Reference in New Issue
Block a user