From ebcda8a71429808fb654de1e09e0b98293cbe119 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 8 Sep 2023 09:25:10 -0700 Subject: [PATCH] Move var_vals from ShapeTracker to LazyBuffer (#1819) --- examples/gpt2.py | 6 ++--- examples/llama.py | 10 +++---- test/test_custom_function.py | 2 +- test/test_symbolic_jit.py | 2 +- test/test_symbolic_ops.py | 2 +- test/test_symbolic_shapetracker.py | 30 ++++++++++----------- tinygrad/codegen/linearizer.py | 2 +- tinygrad/jit.py | 2 +- tinygrad/lazy.py | 43 +++++++++++++++++++----------- tinygrad/ops.py | 4 +-- tinygrad/shape/shapetracker.py | 21 ++++----------- 11 files changed, 63 insertions(+), 61 deletions(-) diff --git a/examples/gpt2.py b/examples/gpt2.py index 0ea0eaaa5f..2fd5bb8596 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -76,8 +76,8 @@ class TransformerBlock: cache_k = cache_k.reshape(cache_k.shape[0], start_pos_var, cache_k.shape[2], cache_k.shape[3]) cache_v = cache_v.reshape(cache_v.shape[0], start_pos_var, cache_v.shape[2], cache_v.shape[3]) # need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vars in cache - cache_k.lazydata.st.var_vals[start_pos_var] = start_pos - cache_v.lazydata.st.var_vals[start_pos_var] = start_pos + cache_k.lazydata.var_vals[start_pos_var] = start_pos + cache_v.lazydata.var_vals[start_pos_var] = start_pos output, cache_k, cache_v = self.attn(self.ln_1(x), cache_k, cache_v, start_pos, mask, jit_ctx=jit_ctx) h = x + output @@ -113,7 +113,7 @@ class Transformer: if seqlen == 1 and start_pos > 0 and getenv("JIT"): start_pos_var = Variable("start_pos", 1, MAX_CONTEXT) pos = self.allpos.shrink(((0, self.allpos.shape[0]), (start_pos_var, start_pos_var+seqlen))) - pos.lazydata.st.var_vals[start_pos_var] = start_pos + pos.lazydata.var_vals[start_pos_var] = start_pos h = self.embed_jitted(tokens, pos) for i, (hi, (cache_k, cache_v)) in enumerate(zip(self.h_jitted, self.kv_caches)): h, cache_k, cache_v = hi(h, cache_k, cache_v, start_pos=start_pos, mask=None, jit_ctx={start_pos_var: start_pos}) diff --git a/examples/llama.py b/examples/llama.py index 5dcc37f103..7c9801a4e9 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -82,7 +82,7 @@ class Attention: keys, values = xk, xv else: assert cache_k is not None and cache_v is not None, "no cache" - assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.st.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.st.var_vals)})" + assert start_pos == sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals) == sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals), f"cache has wrong shape, not ({start_pos} == {sym_infer(cache_k.shape[1], cache_k.lazydata.var_vals)} == {sym_infer(cache_v.shape[1], cache_v.lazydata.var_vals)})" assert seqlen == xk.shape[1] and seqlen == xv.shape[1], "seqlen is wrong shape?!?" keys, values = cache_k.cat(xk, dim=1), cache_v.cat(xv, dim=1) @@ -121,12 +121,12 @@ class TransformerBlock: cache_k = cache_k.reshape(cache_k.shape[0], pos, cache_k.shape[2], cache_k.shape[3]) cache_v = cache_v.reshape(cache_v.shape[0], pos, cache_v.shape[2], cache_v.shape[3]) # need this because we don't reshape back to int shape in the jitted path and we don't have the correct var_vars in cache - cache_k.lazydata.st.var_vals[pos] = start_pos - cache_v.lazydata.st.var_vals[pos] = start_pos + cache_k.lazydata.var_vals[pos] = start_pos + cache_v.lazydata.var_vals[pos] = start_pos # get only the part of freqs_cis that we are using. freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (pos, pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4]))) - freqs_cis.lazydata.st.var_vals[pos] = start_pos + freqs_cis.lazydata.var_vals[pos] = start_pos else: freqs_cis = freqs_cis.shrink(((0, freqs_cis.shape[0]), (start_pos, start_pos+seqlen), (0, freqs_cis.shape[2]), (0, freqs_cis.shape[3]), (0, freqs_cis.shape[4]))) @@ -158,7 +158,7 @@ class Transformer: if seqlen == 1 and JIT: pos = Variable("pos", 1, 1024) freqs_cis = self.freqs_cis.shrink(((0, self.freqs_cis.shape[0]), (pos, pos+seqlen),(0, self.freqs_cis.shape[2]),(0, self.freqs_cis.shape[3]),(0, self.freqs_cis.shape[4]))) - freqs_cis.lazydata.st.var_vals[pos] = start_pos + freqs_cis.lazydata.var_vals[pos] = start_pos h = self.tok_embeddings_jitted(tokens) for i, (layer, (cache_k, cache_v)) in enumerate(zip(self.layers_jitted, self.kv_caches)): h, cache_k, cache_v = layer(h, cache_k, cache_v, start_pos=start_pos, freqs_cis=self.freqs_cis, mask=None, jit_ctx={pos: start_pos}) diff --git a/test/test_custom_function.py b/test/test_custom_function.py index 34daceda62..2644d781c9 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -43,7 +43,7 @@ class ATan2(Function): assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch" self.a, self.b = a, b ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device]) - return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype)) + return create_lazybuffer(a.device, ShapeTracker(a.shape), LoadOps, ast, max(a.dtype, b.dtype), {}) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b)) return grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \ diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py index 8bae172575..5f5ec425f2 100644 --- a/test/test_symbolic_jit.py +++ b/test/test_symbolic_jit.py @@ -171,7 +171,7 @@ class TestSymbolicJit(unittest.TestCase): for i in range(1, 5): a = Tensor.rand(7, 11) symbolic = a.shrink(((3,5),(vi,vi+2))) - symbolic.lazydata.st.var_vals[vi] = i + symbolic.lazydata.var_vals[vi] = i symbolic = jf(symbolic).numpy() expected = f(a.shrink(((3,5),(i,i+2)))).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py index 17639b1a83..3b6b21e2ed 100644 --- a/test/test_symbolic_ops.py +++ b/test/test_symbolic_ops.py @@ -123,7 +123,7 @@ class TestSymbolicOps(unittest.TestCase): for i in range(1, 5): a = Tensor.rand(7, 11) symbolic = a.shrink(((3,5),(vi,vi+2))) - symbolic.lazydata.st.var_vals[vi] = i + symbolic.lazydata.var_vals[vi] = i symbolic = symbolic.numpy() expected = a.shrink(((3,5),(i,i+2))).numpy() np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6) diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index aa3e17b406..8eaf0d6df2 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -45,28 +45,28 @@ class TestSymbolicReshape(unittest.TestCase): for i in range(1, 6): t = Tensor.rand(i, 4).reshape(vi, 4) assert t.shape == (vi, 4) - assert t.lazydata.st.var_vals[vi] == i + assert t.lazydata.var_vals[vi] == i t = Tensor.rand(i, 6).reshape(vi, 2, 3) assert t.shape == (vi, 2, 3) - assert t.lazydata.st.var_vals[vi] == i + assert t.lazydata.var_vals[vi] == i def test_reshape_symbols_reshape_ints(self): vi = Variable("i", 1, 5) for i in range(1, 6): t = Tensor.rand(i, 4).reshape(vi, 4) assert t.shape == (vi, 4) - assert t.lazydata.st.var_vals == {vi: i} + assert t.lazydata.var_vals == {vi: i} t = t.reshape(i, 4) assert t.shape == (i, 4) - assert t.lazydata.st.var_vals == {} + assert t.lazydata.var_vals == {vi: i} def test_reshape_reuse_var_same_value_ok(self): vi = Variable("i", 1, 5) for i in range(1, 6): a = Tensor.rand(i, 4).reshape(vi, 4) b = Tensor.rand(i, 3).reshape(vi, 3) - assert a.lazydata.st.var_vals[vi] == i - assert b.lazydata.st.var_vals[vi] == i + assert a.lazydata.var_vals[vi] == i + assert b.lazydata.var_vals[vi] == i def test_reshape_reuse_var_different_value_ok(self): vi = Variable("i", 1, 10) @@ -74,8 +74,8 @@ class TestSymbolicReshape(unittest.TestCase): a = Tensor.rand(i, 4).reshape(vi, 2) b = Tensor.rand(i, 3).reshape(vi, 3) # a and b have different values of vi - assert a.lazydata.st.var_vals[vi] == 2 * i - assert b.lazydata.st.var_vals[vi] == i + assert a.lazydata.var_vals[vi] == 2 * i + assert b.lazydata.var_vals[vi] == i def test_reshape_into_symbols_bad_shape(self): vi = Variable("i", 1, 10) @@ -115,10 +115,10 @@ class TestSymbolicExpand(unittest.TestCase): vj = Variable("j", 1, 5) a = Tensor([[1], [2], [3]]).expand((3, vi)) assert a.shape == (3, vi) - assert a.lazydata.st.var_vals == {} + assert a.lazydata.var_vals == {} a = a.reshape(3, vi, 1).expand((3, vi, vj)) assert a.shape == (3, vi, vj) - assert a.lazydata.st.var_vals == {} + assert a.lazydata.var_vals == {} def test_plus_expands_constant(self): vi = Variable("i", 1, 5) @@ -152,18 +152,18 @@ class TestShapeTrackerVarVals(unittest.TestCase): vi = Variable("i", 1, 5) vj = Variable("j", 1, 5) t = Tensor.rand(3, 4).reshape(3, vi).reshape(4, vj) - assert t.lazydata.st.var_vals == {vi: 4, vj: 3} + assert t.lazydata.var_vals == {vi: 4, vj: 3} def test_lazy_check_var_vals(self): vi = Variable("i", 1, 5) a = Tensor.rand(3, 4).reshape(3, vi) b = Tensor.rand(5, 6).reshape(vi, 6) - assert a.lazydata.st.var_vals == {vi: 4} - assert b.lazydata.st.var_vals == {vi: 5} + assert a.lazydata.var_vals == {vi: 4} + assert b.lazydata.var_vals == {vi: 5} c = a@b - # shapetracker works with symbolic shape and doesn't check / propagate the underlying variable values + # shapetracker works with symbolic shape and doesn't check the underlying variable values assert c.shape == (3, 6) - assert c.lazydata.st.var_vals == {} + assert c.lazydata.var_vals == {vi: 4} if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 29e41797ad..88262f2621 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -197,7 +197,7 @@ class Linearizer(OptimizedKernel): for i,b in enumerate(self.bufs): if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized] # add variables from symbolic shapes - for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key): + for var in sorted(set(v for buf in self.ast.buffers for v in buf.var_vals), key=lambda k: k.key): assert var.expr is not None self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32)) # define local buffers diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 92c916c7d9..30b0d13312 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -30,7 +30,7 @@ class TinyJit: assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT" if self.cnt >= 2: try: var_vals: Dict[Variable, int] = kwargs["jit_ctx"] - except KeyError: var_vals = merge_dicts([arg.lazydata.st.var_vals for arg in args if arg.__class__ is Tensor]) + except KeyError: var_vals = merge_dicts([arg.lazydata.var_vals for arg in args if arg.__class__ is Tensor]) if len(var_vals) > 1: var_vals = dict(sorted(var_vals.items(), key=lambda kv: kv[0].key)) for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items(): assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}" diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index a9498be19d..508ca3a0a5 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -5,10 +5,10 @@ from weakref import ref, WeakSet, WeakValueDictionary import numpy as np from tinygrad.graph import log_op -from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType +from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp from tinygrad.shape.shapetracker import ShapeTracker, View, get_contraction -from tinygrad.shape.symbolic import Node +from tinygrad.shape.symbolic import Node, Variable from tinygrad.runtime.lib import RawConst, RawBuffer, RawBufferMapped, RawBufferTransfer from tinygrad.runtime.ops_cpu import RawNumpyBuffer @@ -96,25 +96,27 @@ def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: ret def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x) lazycache: WeakValueDictionary = WeakValueDictionary() -def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType): +def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int]): # fromcpu aren't cached - if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype) + if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype, var_vals) # wop is the deduping key. i feel this used to compare more deeply - wop = (device, dtype, optype, ref(op)) + wop = (device, dtype, optype, ref(op), tuple(sorted(var_vals.keys()))) if wop in lazycache: for x in op.buffers: x.children.add(lazycache[wop]) return lazycache[wop] - lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype) + lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, var_vals) return ret UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP} class LazyBuffer: __deletable__ = ('op',) - def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None): + def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None): self.st: ShapeTracker = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker + self.var_vals: Dict[Variable, int] = var_vals + self.var_vals_key: Tuple[Variable, ...] = tuple(sorted(self.var_vals.keys())) self.device, self.shape, self.optype, self.dtype = device, self.st.shape, optype, dtype self.realized: Optional[RawBuffer] = src self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized @@ -132,8 +134,8 @@ class LazyBuffer: def __repr__(self): return f"" @property def key(self): - if self.realized: return (self.dtype, self.realized.key, self.st.key) - return (self.dtype, self.op.op, self.st.key) + if self.realized: return (self.dtype, self.realized.key, self.st.key, self.var_vals_key) + return (self.dtype, self.op.op, self.st.key, self.var_vals_key) def _device_extra_args(self) -> Dict[str, str]: return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {} @@ -174,7 +176,7 @@ class LazyBuffer: @staticmethod def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer: - return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype) + return create_lazybuffer(device, ShapeTracker(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {}) # create a constant with the shape and dtype of self def const(self, val:Union[float, int]) -> LazyBuffer: @@ -183,11 +185,11 @@ class LazyBuffer: def contiguous(self:LazyBuffer) -> LazyBuffer: if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one - return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype) + return create_lazybuffer(self.device, ShapeTracker(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals) @staticmethod def fromCPU(x: np.ndarray) -> LazyBuffer: - return LazyBuffer("CPU", ShapeTracker(x.shape, [View(x.shape, tuple(st//x.itemsize for st in x.strides))]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), RawNumpyBuffer.fromCPU(x)) + return LazyBuffer("CPU", ShapeTracker(x.shape, [View(x.shape, tuple(st//x.itemsize for st in x.strides))]), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x)) def toCPU(self) -> np.ndarray: assert self.dtype.np, f"{self.dtype} is not supported in toCPU" @@ -220,7 +222,7 @@ class LazyBuffer: # remove the buffers from any (childless) BinaryOps that feed into this srcs = tuple([x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs]) # type: ignore - return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype) + return create_lazybuffer(out_device, ShapeTracker(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, self.var_vals) def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[Union[Node,int], ...], Tuple[Tuple[int, int], ...]]) -> LazyBuffer: if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children: @@ -230,12 +232,12 @@ class LazyBuffer: root = get_movementroot(self) if root.st.contiguous and root != self and prod(st.shape) == prod(root.shape): return root.reshape(st.shape) - return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype) + return create_lazybuffer(self.device, st, MovementOps, LazyOp(op, (self,), arg), self.dtype, self.var_vals) def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer: if self.shape == tuple(new_shape): return self srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,) - return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype) + return create_lazybuffer(self.device, ShapeTracker(new_shape), ReduceOps, LazyOp(op, srcs, new_shape), self.dtype, self.var_vals) def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer: if any(not isinstance(s, int) for s in self.shape) or prod(self.shape) // prod(new_shape) < 32768: return self._reduce_op(op, new_shape) # The amount of work should be big enough to take the benefit of "2 kernels" approach. @@ -246,8 +248,19 @@ class LazyBuffer: def reshape(self:LazyBuffer, arg:Tuple[Union[Node, int], ...]) -> LazyBuffer: if self.shape == arg: return self + new_ints, new_nodes = partition(arg, lambda s: isinstance(s, int)) + if new_nodes and all(isinstance(s, int) for s in self.shape): + # reshape from all int shape into shape with a variable, update the variable value + assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape" + new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints) + if new_var not in self.var_vals: + assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]" + self.var_vals[new_var] = new_val + else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}" if not self.realized and self.op.op == MovementOps.RESHAPE: + assert isinstance(self.op.src[0], LazyBuffer) self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why?? + self.op.src[0].var_vals = self.var_vals return self.op.src[0].reshape(arg) return self.shuffle_and_prune_movement_ops(ShapeTracker(self.st).reshape(arg), MovementOps.RESHAPE, arg) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 51e237daad..3dcf9c4056 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -198,7 +198,7 @@ class Compiled: from tinygrad.jit import CacheCollector CacheCollector._mark_output_buffer(output.output_buffer) # update the output var_vals from src - output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key)) + output.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key)) from tinygrad.codegen.linearizer import Linearizer k = Linearizer(ast, output, self.linearizer_opts) @@ -218,5 +218,5 @@ class Compiled: if prg.name == getenv("PRINT_PRG", ''): print(prg.prg) - prg.exec(k.bufs, var_vals=output.st.var_vals) + prg.exec(k.bufs, var_vals=output.var_vals) return output.realized diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 7dd95eb55b..e038ed1783 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -1,8 +1,8 @@ # ShapeTracker allows movement operations to a buffer that don't require a copy to be made. from __future__ import annotations import functools -from typing import Dict, Tuple, Union, List, Optional, NamedTuple -from tinygrad.helpers import prod, DEBUG, partition +from typing import Tuple, Union, List, Optional, NamedTuple +from tinygrad.helpers import prod, DEBUG from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, is_sym_int @functools.lru_cache(maxsize=None) @@ -127,11 +127,10 @@ def get_unsafe_resize_offset(strides, arg): return sum([s * x[0] for s, x in zip(strides,arg)]) class ShapeTracker: - __slots__ = "views", "var_vals" + __slots__ = "views" def __init__(self, shape:Union[ShapeTracker, Tuple[Union[Node,int], ...]], views:Optional[List[View]]=None): self.views: List[View] = views if views is not None else [*shape.views] if isinstance(shape, ShapeTracker) else [View(shape)] - self.var_vals: Dict[Variable, int] = shape.var_vals if isinstance(shape, ShapeTracker) else {} - def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views}, var_vals={self.var_vals})" + def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})" def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views]) @property @@ -141,7 +140,7 @@ class ShapeTracker: def shape(self) -> Tuple[int, ...]: return self.views[-1].shape # NOTE: real type is Tuple[Union[Node, int], ...] but mypy complains about prod(shape) @property - def key(self) -> Tuple[Tuple[View, ...], Tuple[Variable, ...]]: return tuple(self.views), tuple(sorted(self.var_vals.keys())) + def key(self) -> Tuple[View, ...]: return tuple(self.views) # this is the real size (ish) def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0]) @@ -233,16 +232,6 @@ class ShapeTracker: def reshape(self, new_shape: Tuple[Union[Node,int], ...]): if self.views[-1].shape == new_shape: return self - new_ints, new_nodes = partition(new_shape, lambda s: isinstance(s, int)) - if new_nodes and all(isinstance(s, int) for s in self.shape): - # reshape from all int shape into shape with a variable, update the variable value - assert len(new_nodes) == 1 and isinstance(new_nodes[0], Variable), "only support adding one Variable to the int shape" - new_var, new_val = new_nodes[0], prod(self.shape) // prod(new_ints) - if new_var not in self.var_vals: - assert new_var.min <= new_val <= new_var.max, f"variable value {new_val} out of range [{new_var.min}, {new_var.max}]" - self.var_vals[new_var] = new_val - else: assert self.var_vals[new_var] == new_val, f"value conflicts, was {self.var_vals[new_var]}, set to {new_val}" - elif not new_nodes: self.var_vals = {} assert all(is_sym_int(x) and x > 0 for x in new_shape), f"shape must be symbolic ints and can't contain 0 or negative numbers {new_shape}" # only check size for int shapes. we don't check symbolic here as long as the reshape itself can be done if all(isinstance(s, int) for s in self.shape) and all(isinstance(s, int) for s in new_shape):