safe changes for new symbolic [pr] (#6864)

This commit is contained in:
George Hotz
2024-10-03 20:39:15 +08:00
committed by GitHub
parent 17068410e6
commit 4b6732c4f6
8 changed files with 24 additions and 14 deletions

View File

@@ -140,7 +140,7 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
try:
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals: Dict[Variable, int] = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
exiting, st = False, time.perf_counter()
dev = Device[lin.opts.device]
while not exiting:
@@ -198,7 +198,7 @@ def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_
assert dev.compiler is not None
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals: Dict[Variable, int] = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
var_vals: Dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
p = lin.to_program()
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

View File

@@ -3,7 +3,7 @@ import math
from typing import Tuple, Optional
from tinygrad.helpers import argsort
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
from tinygrad.ops import ReduceOps
from tinygrad.ops import ReduceOps, resolve
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.symbolic import sint
@@ -170,7 +170,7 @@ class Max(Function):
# NOTE: this is sum in reverse
class Expand(Function):
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so)
self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
return x.expand(shape)
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU, identity_element, MathTrait
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU
from tinygrad.ops import identity_element, MathTrait, resolve
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
@@ -167,7 +168,7 @@ class LazyBuffer(MathTrait):
def _reduce_op(self, op:ReduceOps, axis:Tuple[int, ...]) -> LazyBuffer:
assert all(0 <= x < len(self.shape) for x in axis), f"axis args {axis} out of range for shape {self.shape}"
axis = tuple(sorted([x for x in axis if self.shape[x] != 1]))
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
if len(axis) == 0: return self
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.st.reduce(axis)), self.dtype, op, axis, (self,))

View File

@@ -154,7 +154,7 @@ END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.E
def resolve(x, default:bool=True):
try: return bool(x)
except ValueError: return default
def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.max)
def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.vmax)
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
class UOp(MathTrait):
@@ -269,7 +269,7 @@ class UOp(MathTrait):
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.expr)
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.arg)
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg

View File

@@ -52,7 +52,7 @@ class Program:
special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size
assert special_size is not None
special_size[int(u.arg[0][-1])] = u.arg[1]
self.vars = sorted(self.vars, key=lambda v: v.expr)
self.vars = sorted(self.vars, key=lambda v: v.arg)
self.outs = sorted(dedup(self.outs))
self._ran_post_init = True

View File

@@ -11,6 +11,13 @@ class Node:
b: Union[Node, int]
min: int
max: sint
# helpers for the migration
@property
def vmin(self): return self.min
@property
def vmax(self): return self.max
def render(self, ops=None, ctx=None) -> Any:
if ops is None: ops = render_python
assert self.__class__ in (Variable, NumNode) or self.min != self.max
@@ -119,6 +126,8 @@ class Variable(Node):
return super().__new__(cls)
def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
@property
def arg(self): return self.expr
def __init__(self, expr:str, nmin:int, nmax:sint):
self.expr, self.min, self.max = expr, nmin, nmax

View File

@@ -93,7 +93,7 @@ class View:
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def size(self) -> int:
# NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.
ret = prod([x.max if isinstance(x, Node) else x for x in self.shape])
ret = prod([x.vmax if isinstance(x, Node) else x for x in self.shape])
assert isinstance(ret, int), f"{ret=} is not int"
return ret
@@ -174,7 +174,7 @@ class View:
# Try to project vm2's mask on to vm1.
newb, newe, bad = [0] * len(vm1.shape), list(vm1.shape), False
for d2, ((b, e), o, (_, t)) in enumerate(zip(vm2.mask, origin, reversed(extents))):
if not (t.min < b or t.max >= e): continue
if not (t.vmin < b or t.vmax >= e): continue
if not isinstance(o, int) or not isinstance(b, int) or not isinstance(e, int):
bad = True
continue

View File

@@ -10,7 +10,7 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, truncate, smax
from tinygrad.ops import MetaOps, truncate, smax, resolve
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
from tinygrad.engine.realize import run_schedule, memory_planner
@@ -1592,7 +1592,7 @@ class Tensor:
"""
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
def var(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):
"""
@@ -1617,7 +1617,7 @@ class Tensor:
```
"""
squares = (self - self.mean(axis=axis, keepdim=True)).square()
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if si != so])
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
def std(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, correction=1):