From 532b7b018cd8f185abb2cecb682e396f76c887bf Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 24 Oct 2024 15:10:49 +0700
Subject: [PATCH] add smin/smax (#7253)

* add smin/smax

* don't create var with var

* better test errors

* add failing test

* enable shape simplification

* fix tests

* Update view.py

* simpler and simplify
---
 test/helpers.py                    |  2 +-
 test/test_symbolic_shapetracker.py |  2 +-
 test/test_tensor_variable.py       |  8 ++++++++
 tinygrad/ops.py                    | 15 ++++++++++++++-
 tinygrad/shape/view.py             | 14 ++++++++------
 tinygrad/tensor.py                 |  7 ++++---
 6 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/test/helpers.py b/test/helpers.py
index ca75dbdee5..fbf0dd829d 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -21,7 +21,7 @@ def assert_jit_cache_len(fxn, expected_len):
     return
   # until we have a better way of typing the prg in ExecItem
   if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
-    assert len(fxn.jit_cache) == expected_len, len(fxn.jit_cache)
+    assert len(fxn.jit_cache) == expected_len, f"expected {expected_len}, got {len(fxn.jit_cache)}"
   else:
     assert len(fxn.jit_cache) == 1, len(fxn.jit_cache)
     # until we have a better way of typing the prg in ExecItem
diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py
index 55b759532e..b876f83074 100644
--- a/test/test_symbolic_shapetracker.py
+++ b/test/test_symbolic_shapetracker.py
@@ -208,7 +208,7 @@ class TestSymbolicExpand(unittest.TestCase):
       vi = Variable("i", 1, 5).bind(i)
       a = Tensor.rand(3, i).reshape(3, vi)
       a = a + 1
-      assert a.shape == (3, vi)
+      self.assertTupleEqual(a.shape, (3, vi))
 
 class TestSymbolicShrink(unittest.TestCase):
   def test_shrink_symbols(self):
diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py
index c6592753bd..26448ffd6f 100644
--- a/test/test_tensor_variable.py
+++ b/test/test_tensor_variable.py
@@ -55,6 +55,14 @@ class TestTensorVariable(unittest.TestCase):
     ret = t.var().item()
     assert ret == 0
 
+  def test_symbolic_pad2d(self):
+    vv = Variable("a", 1, 10).bind(2)
+    t = Tensor.ones(2, 2).contiguous()
+    t = t.pad2d([vv, vv, vv, vv]).mean()
+    ones = 4
+    zeros = 6+6+4+4+6+6
+    self.assertAlmostEqual(t.item(), ones/(ones+zeros))
+
   @unittest.skip("symbolic arange isn't supported")
   def test_symbolic_arange(self):
     vv = Variable("a", 1, 10).bind(2)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index fac4c41a1a..c99df8cce2 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -164,7 +164,15 @@ def resolve(x, default:bool=True):
   assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
   # NOTE: generating the text for the exception is expensive, so we do this
   return bool(sx.vmin) if (sx:=x.simplify()).vmin == sx.vmax else default
-def smax(lst): return max(lst, key=lambda x: x if isinstance(x, int) else x.vmax)
+
+# smax/smin are replacements for max/min that preserve symbolic
+def _suop(lst, uop_fxn, python_fxn):
+  max_uop, max_num = partition(lst, lambda x: isinstance(x, UOp))
+  if len(max_uop): return functools.reduce(uop_fxn, (max_uop + [python_fxn(max_num)]) if len(max_num) else max_uop).ssimplify()
+  return python_fxn(max_num)
+def smax(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.max, max)
+def smin(*lst): return _suop(lst[0] if isinstance(lst[0], (tuple, list)) else lst, UOp.min, min)
+
 def ssimplify(uop): return uop.ssimplify() if isinstance(uop, UOp) else uop
 def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
 
@@ -307,6 +315,7 @@ class UOp(MathTrait):
   @staticmethod
   def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
+    assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
     return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
   @property
   def expr(self):
@@ -1056,6 +1065,10 @@ renderer = PatternMatcher([
   (UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
   (UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
   (UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
+  (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(UOps.NOOP, arg=f"(-{x.src[0].arg})")),
+  (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(UOps.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
+  (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.MULACC),
+   lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
   (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.WHERE),
    lambda x: UOp(UOps.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
   (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index 2542b99c06..3007b1afda 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 import functools, operator, itertools, math
 from dataclasses import dataclass
-from typing import Tuple, List, Optional, Dict, Set, cast, Union
+from typing import Tuple, List, Optional, Dict, Set, cast
 from tinygrad.dtype import dtypes
-from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer
+from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin
 from tinygrad.helpers import prod, all_int, argsort
 
 @functools.lru_cache(maxsize=None)
@@ -125,9 +125,10 @@ class View:
       offset += sum((strides[i] * mask[i][0]) if e else 0 for i, e in enumerate(elim))
       strides = tuple(0 if e else st for st,e in zip(strides, elim))
     # simplify as we go
-    if isinstance(offset, UOp): offset = cast(Union[UOp, int], offset.ssimplify())
+    if isinstance(offset, UOp): offset = cast(sint, offset.ssimplify())
+    shape = tuple(cast(sint, x.ssimplify()) if isinstance(x, UOp) else x for x in shape)
+    # TODO: enabling stride simplification breaks it
     """
-    shape = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in shape)
     strides = tuple(x.ssimplify() if isinstance(x, UOp) else x for x in strides)
     if mask: mask = tuple((s.ssimplify() if isinstance(s, UOp) else s, e.ssimplify() if isinstance(e, UOp) else e) for s,e in mask)
     """
@@ -174,6 +175,7 @@ class View:
 
     # Merge dimensions in vm2 if required.
    # NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
+    if not all_int(vm1.shape): return None
     idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
     merged_size, merged_term = 1, UOp.const(dtypes.int, 0)
     extents: List[Tuple[sint, UOp]] = []
@@ -232,9 +234,9 @@ class View:
     offset = sum([s * x[0] for s, x in zip(self.strides,arg)])
     if self.mask:
       # move the old mask
-      nmask = tuple([(max(0, min(mx-ax,ay-ax)), max(0, min(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
+      nmask = tuple([(smax(0, smin(mx-ax,ay-ax)), smax(0, smin(my-ax,ay-ax))) for (mx,my),(ax,ay) in zip(self.mask, arg)])
       # merge the masks if we have two
-      mask = tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
+      mask = tuple([(smax(mx1, mx2), smin(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)]) if mask is not None else nmask
     shape = [y-x for x,y in arg]
     if mask is not None and all(m[0] == 0 and m[1] == s for m,s in zip(mask, shape)): mask = None
     return View.create(tuple(s.ssimplify() if isinstance(s, UOp) else s for s in shape), self.strides, self.offset+offset, mask)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 5f1e6e59f9..0134e94b4d 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
 from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch
 from tinygrad.multi import MultiLazyBuffer
-from tinygrad.ops import MetaOps, smax, resolve, UOp, UOps, BinaryOps, sint, Variable
+from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable
 from tinygrad.device import Device, Buffer, BufferOptions
 from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.realize import run_schedule
@@ -380,6 +380,7 @@ class Tensor:
     if y.op is UOps.ALU:
       if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
       if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
+      if y.arg is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
     raise RuntimeError(f"unhandled UOp {y}")
 
   # ***** creation entrypoint *****
@@ -1364,9 +1365,9 @@ class Tensor:
     print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
     ```
     """
-    pads = tuple((max(p0, 0), max(p1, 0)) for p0, p1 in zip(padding[::2], padding[1::2]))[::-1]
+    pads = tuple((smax(p0, 0), smax(p1, 0)) for p0, p1 in zip(padding[::2], padding[1::2]))[::-1]
     padded = self.pad((None,) * (self.ndim - len(padding) // 2) + tuple(pads), value=value)
-    shrink = tuple((-min(p0, 0), min(p1 + s, s)) for p0, p1, s in zip(padding[::2], padding[1::2], padded.shape[::-1]))[::-1]
+    shrink = tuple((-smin(p0, 0), smin(p1 + s, s)) for p0, p1, s in zip(padding[::2], padding[1::2], padded.shape[::-1]))[::-1]
     return padded.shrink((None,) * (self.ndim - len(padding) // 2) + shrink)
 
   @property
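
For reference, a rough usage sketch of the new `smax`/`smin` helpers follows. It is not part of the patch; it assumes the patch is applied and that `smax`, `smin`, and `Variable` are importable from `tinygrad.ops`, as the diff's own imports show.

```python
from tinygrad.ops import smax, smin, Variable

v = Variable("i", 1, 10)      # symbolic int bounded to [1, 10]

# with plain ints, smax/smin behave like the builtins
assert smax(2, 5) == 5
assert smin([2, 5, 3]) == 2   # a single list/tuple argument also works

# with a UOp among the arguments, a symbolic max/min node is built
# (the old smax just returned the element with the largest vmax)
print(smax(v, 3))             # UOp representing max(i, 3)
print(smin(v, 3))             # UOp representing min(i, 3)
```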