diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e9e2512fb0..fbf03cd94d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -165,7 +165,7 @@ jobs: FORWARD_ONLY=1 GPU=1 IMAGE=2 python3 test/test_ops.py - name: Test openpilot model run: | - ALLOWED_KERNEL_COUNT=197 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py + ALLOWED_KERNEL_COUNT=199 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py python3 -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000' DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py diff --git a/examples/llama.py b/examples/llama.py index 1340057c25..0069b6849b 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -216,6 +216,7 @@ if __name__ == "__main__": parser.add_argument('--timing', action='store_true', help="Print timing per token") parser.add_argument('--profile', action='store_true', help="Output profile data to out.prof") parser.add_argument('--large', action='store_true', help="Use the 13B model instead of the 7B one") + parser.add_argument('--tinyfake', action='store_true', help="Use the fake very small model") args = parser.parse_args() chatbot = args.prompt == None @@ -257,6 +258,11 @@ if __name__ == "__main__": del weights0 del weights1 + elif args.tinyfake: + # GRAPH=1 python3 examples/llama.py --timing --prompt "Hello." --temperature=0 --tinyfake --count 1 + model = Transformer(**args_small) + from tinygrad.nn.optim import get_parameters + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) else: model = Transformer(**args_7B) with Timing("loaded weights in ", lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s"): diff --git a/extra/utils.py b/extra/utils.py index ef007e7fe2..04d4e31e47 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -109,7 +109,7 @@ def load_single_weight(t:Tensor, myfile, shape, strides, dtype, storage_offset, # this needs real APIs if t.device in ["METAL", "CLANG", "LLVM"]: del t.lazydata.op - t.lazydata.realized = t.lazydata.dbuffer.buffer(prod(t.shape), dtype=t.dtype) + t.lazydata.realized = Device[t.lazydata.device].buffer(prod(t.shape), dtype=t.dtype) myfile.readinto(t.lazydata.realized._buffer()) else: def _mmap(lna): diff --git a/openpilot/compile.py b/openpilot/compile.py index b7a66e1c79..31c0c66f55 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -49,7 +49,7 @@ def compile(dat, output_fn): Tensor.manual_seed(1337) Tensor.no_grad = True using_graph = graph.GRAPH - if getenv("GRAPH") < 2: graph.GRAPH = False + if getenv("GRAPH") < 3: graph.GRAPH = False onnx_model = onnx.load(io.BytesIO(dat)) run_onnx = get_run_onnx(onnx_model) diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index 061b832da6..772852a411 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -14,16 +14,130 @@ from tinygrad.ops import GlobalCounters, MovementOps, ReduceOps from tinygrad.lazy import PUSH_PERMUTES class CLCache(): + def __init__(self, allowed=None, strict=False, preclear=True): self.allowed, self.strict, self.preclear = allowed, strict, preclear def __enter__(self): - gc.collect() - for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]: - x.realize() + if self.preclear: + gc.collect() + for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]: + x.realize() + GlobalCounters.reset() GlobalCounters.cache = [] print("cache: entering") def __exit__(self, type, value, traceback): - print(f"cache: exiting with size {len(GlobalCounters.cache)}") + print(f"cache: exiting with size {len(GlobalCounters.cache)}", f"allowed {self.allowed}" if self.allowed is not None else "") + if self.allowed is not None: + assert len(GlobalCounters.cache) <= self.allowed and (not self.strict or len(GlobalCounters.cache) == self.allowed), "used too many kernels!" GlobalCounters.cache = None +from models.convnext import ConvNeXt +from models.efficientnet import EfficientNet +from models.resnet import ResNet18 +from models.vit import ViT +from tinygrad.nn.optim import get_parameters + +@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") +class TestInferenceMinKernels(unittest.TestCase): + def setUp(self): + Tensor.training = False + + @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES") + def test_convnext(self): + model = ConvNeXt() + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + img = Tensor.randn(1, 3, 224, 224) + with CLCache(129): + model(img).realize() + + def test_enet(self): + model = EfficientNet(has_se=False) + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + img = Tensor.randn(1, 3, 224, 224) + with CLCache(51): + model.forward(img).realize() + + def test_resnet(self): + model = ResNet18() + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + img = Tensor.randn(1, 3, 224, 224) + with CLCache(31): # NOTE: this should be 4 lower + model.forward(img).realize() + + def test_vit(self): + model = ViT(embed_dim=192, num_heads=3) + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + img = Tensor.randn(1, 3, 224, 224) + with CLCache(223): # NOTE: this is way too high + out = model.forward(img) + assert len(GlobalCounters.cache) == 0, f"ViT prerealized?" + out.realize() + + def test_llama(self): + from examples.llama import Transformer, onehot_encode + args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000} + model = Transformer(**args_tiny) + for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) + with CLCache(85): + model(onehot_encode([1,2,3,4], vocab_size=args_tiny['vocab_size']), 0).realize() + +@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") +class TestOptBinOp(unittest.TestCase): + def _test_no_binop_rerun(self, f1, f2=None, allowed=1): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + c = f1(a, b) + if f2 is not None: d = f2(a, b) + c.realize() + if f2 is not None: d.realize() + assert len(GlobalCounters.cache) == allowed, "binop was rerun!" + if f2 is not None: np.testing.assert_allclose(c.numpy().ravel(), d.numpy().ravel(), rtol=1e-3, atol=1e-5) + + def test_no_binop_rerun(self): return self._test_no_binop_rerun(lambda a,b: a*b, lambda a,b: (a*b).reshape(16, 16, 1)) + def test_no_binop_rerun_alt(self): return self._test_no_binop_rerun(lambda a,b: (a*b).reshape(16, 16, 1), lambda a,b: a*b) + def test_no_binop_rerun_reduce_broadcast(self): return self._test_no_binop_rerun(lambda a,b: a.sum()+b, lambda a,b: a.sum().reshape(1,1)+b, allowed=2) + def test_no_binop_rerun_transposed(self): return self._test_no_binop_rerun(lambda a,b: (a.T*b.T).T, lambda a,b: a*b) + def test_no_binop_rerun_mid_reshape(self): return self._test_no_binop_rerun(lambda a,b: (a*b).reshape(256)+a.reshape(256)) + + # currently non working tests + #def test_no_binop_rerun_preshape(self): return self._test_no_binop_rerun(lambda a,b: a.reshape(16, 16, 1)*b.reshape(16, 16, 1), lambda a,b: a*b) + #def test_no_binop_rerun_reduce(self): return self._test_no_binop_rerun(lambda a,b: (a*b).sum(), lambda a,b: (a*b).reshape(16, 16, 1).sum()) + #def test_no_binop_rerun_reduce_alt(self): return self._test_no_binop_rerun(lambda a,b: a.sum(1)+b[0], lambda a,b: a.sum(1).reshape(1,16)+b[0]) + +@unittest.skip("elementwise with >1 reduce inputs currently don't fuse") +@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") +class TestOptReduceLoop(unittest.TestCase): + def test_loop_left(self): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + t = a.sum(0) + b = t.reshape(16,1).expand(16,16).sum(0) + c = (t+b) + c.realize() + assert len(GlobalCounters.cache) == 2, "loop left fusion broken" + + def test_loop_right(self): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + t = a.sum(0) + b = t.reshape(16,1).expand(16,16).sum(0) + c = (b+t) + c.realize() + assert len(GlobalCounters.cache) == 2, "loop right fusion broken" + +@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") +class TestOptWChild(unittest.TestCase): + def test_unrealized_child(self): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + c = (a*b).sum() + d = c+1 + e = c+2 + d.realize() + assert len(GlobalCounters.cache) == 2, "don't fuse if you have children" + @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestOpt(unittest.TestCase): def test_muladd(self): @@ -51,7 +165,7 @@ class TestOpt(unittest.TestCase): with CLCache(): img_bn = bn(img).realize() print(img_bn) - assert len(GlobalCounters.cache) == 3, "optimizer didn't fold batchnorm" + assert len(GlobalCounters.cache) == 3, f"optimizer didn't fold batchnorm, got {len(GlobalCounters.cache)}" Tensor.training = False def test_fold_conv_sgd(self): @@ -94,7 +208,7 @@ class TestOpt(unittest.TestCase): img_conv = bn(c1(img)).relu().realize() with CLCache(): img_conv = bn(c1(img)).relu().realize() - assert len(GlobalCounters.cache) == 1, "optimizer didn't fold conv-batchnorm at test time" + assert len(GlobalCounters.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(GlobalCounters.cache)}" def test_fold_conv_batchnorm(self): Tensor.training = True @@ -104,7 +218,7 @@ class TestOpt(unittest.TestCase): with CLCache(): img_conv = bn(c1(img)).relu().realize() print(img_conv) - assert len(GlobalCounters.cache) == 4, "optimizer didn't fold conv-batchnorm" + assert len(GlobalCounters.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(GlobalCounters.cache)}" Tensor.training = False def test_fold_conv_elu(self): @@ -164,9 +278,10 @@ class TestOpt(unittest.TestCase): np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,1,16).transpose(2,1,0), d.numpy(), rtol=1e-3, atol=1e-5) if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" + # TODO: push permute through expansion reshape @unittest.skip("expansion can't push expand permute yet") + @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES") def test_permute_was_pushed_through_expand_reshape(self): - if not PUSH_PERMUTES: return a = Tensor.randn(16, 16, 16) with CLCache(): c = a.sum(2) @@ -176,33 +291,8 @@ class TestOpt(unittest.TestCase): np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0).reshape(4,4,4,4), d.numpy(), rtol=1e-3, atol=1e-5) if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" - @unittest.skip("this is broken") - def test_no_binop_rerun(self): - a = Tensor.randn(16, 16) - b = Tensor.randn(16, 16) - with CLCache(): - c = a*b - d = (a*b).reshape(16, 16, 1) - c.realize() - d.realize() - assert len(GlobalCounters.cache) == 1, "binop was rerun!" - np.testing.assert_allclose(c.numpy(), d.numpy(), rtol=1e-3, atol=1e-5) - - @unittest.skip("this is broken") - def test_no_binop_rerun_alt(self): - a = Tensor.randn(16, 16) - b = Tensor.randn(16, 16) - with CLCache(): - c = (a*b).reshape(16, 16, 1) - d = a*b - c.realize() - d.realize() - assert len(GlobalCounters.cache) == 1, "binop was rerun!" - np.testing.assert_allclose(c.numpy(), d.numpy(), rtol=1e-3, atol=1e-5) - - # TODO: should be okay with PUSH_PERMUTES + @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES") def test_no_reduceop_rerun(self): - if PUSH_PERMUTES: return a = Tensor.randn(16, 16, 16) with CLCache(): c = a.sum(2) @@ -213,9 +303,8 @@ class TestOpt(unittest.TestCase): np.testing.assert_allclose(c.numpy().transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5) assert cache_len == 1, "reduceop was rerun!" - # TODO: should be okay with PUSH_PERMUTES + @unittest.skipIf(PUSH_PERMUTES, "this test is brokem with PUSH_PERMUTES") def test_no_reduceop_rerun_alt(self): - if PUSH_PERMUTES: return a = Tensor.randn(16, 16, 16) with CLCache(): c = a.sum(2).permute(1,0) diff --git a/test/test_custom_function.py b/test/test_custom_function.py index 8b01fa7780..72aa7ba575 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -8,7 +8,7 @@ from tinygrad.helpers import prod, dtypes # *** first, we implement the atan2 op at the lowest level *** # `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers -from tinygrad.lazy import LazyBuffer, Device +from tinygrad.lazy import LazyBuffer, create_lazybuffer, Device from tinygrad.ops import ASTRunner # we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer @@ -39,7 +39,7 @@ class ATan2(Function): assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch" self.a, self.b = a, b ast = LazyOp(LoadOps.CUSTOM, (a.contiguous(), b.contiguous()), {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device]) - return LazyBuffer(a.device, a.shape, LoadOps, ast, max(a.dtype, b.dtype)) + return create_lazybuffer(a.device, a.shape, LoadOps, ast, max(a.dtype, b.dtype)) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: denom = (self.a.binary_op(BinaryOps.MUL, self.a)).binary_op(BinaryOps.ADD, self.b.binary_op(BinaryOps.MUL, self.b)) return grad_output.binary_op(BinaryOps.MUL, self.b.binary_op(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None, \ diff --git a/test/test_ops.py b/test/test_ops.py index a75558c25a..3aa70434b6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -197,6 +197,8 @@ class TestOps(unittest.TestCase): [[1.0,1.0,0.0,1.0]], ]) helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1)) + def test_mean(self): + helper_test_op([(3,4,5,6)], lambda x: x.mean()) def test_mean_axis(self): helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2))) def test_log_softmax(self): diff --git a/test/unit/test_flopcounter.py b/test/unit/test_flopcounter.py index b96708a459..ba6ac0fa74 100644 --- a/test/unit/test_flopcounter.py +++ b/test/unit/test_flopcounter.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import unittest from typing import NamedTuple, Tuple -from tinygrad.ops import LazyOp, BinaryOps, get_lazyop_info +from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info from tinygrad.helpers import DType, dtypes class TestBuffer(NamedTuple): @@ -37,5 +37,12 @@ class TestFlopCounter(unittest.TestCase): info = get_lazyop_info(op2) self.assertEqual(info.flops, 12) + def test_flops_red(self): + op0 = LazyOp(BinaryOps.MUL, (self.buf0,self.buf1,), None) + op1 = LazyOp(ReduceOps.SUM, (op0,), (1,)) + op2 = LazyOp(BinaryOps.ADD, (op1, op1,), None) + info = get_lazyop_info(op2) + self.assertEqual(info.flops, 9) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/graph.py b/tinygrad/graph.py index 46b2dcbae1..b3173f246e 100644 --- a/tinygrad/graph.py +++ b/tinygrad/graph.py @@ -42,14 +42,11 @@ def str_dtype(dtyp): ret = str(dtyp)[7:] return "" if ret == 'float' else f"\n{ret}" -def log_op(ret: LazyBuffer, ast: LazyOp, show_graph: Optional[bool] = None): +def log_op(ret: LazyBuffer, ast: LazyOp, show_graph: Optional[bool] = None, phantom=False): if show_graph is None: show_graph = bool(GRAPH) if not DEBUG and not show_graph: return op: List[Op] = [x.op for x in get_lazyops(ast)] inp: List[LazyBuffer] = get_buffers(ast) - if len(inp) == 1 and inp[0] == ret: - if show_graph and nm(ret) in G.nodes: G.nodes[nm(ret)]['style'] += ', bold' - return # don't log self loops oporder = [LoadOps, FusedOps, ReduceOps, BinaryOps, UnaryOps, MovementOps] optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0]) cnts[optype] += 1 @@ -59,14 +56,15 @@ def log_op(ret: LazyBuffer, ast: LazyOp, show_graph: Optional[bool] = None): dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore for x in inp: - G.add_edge(nm(x), nm(ret), label=get_sop(op)) + G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060' if phantom else 'black') if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype) if nm(ret) not in G.nodes: G.add_node(nm(ret)) G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype) - G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else str())) if optype in top_colors else "#ffffff" - G.nodes[nm(ret)]['style'] = 'filled, dashed' if dashed else 'filled' + G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('60' if phantom else ('80' if dashed else str()))) if optype in top_colors else "#ffffff" + G.nodes[nm(ret)]['color'] = 'white' if phantom else 'black' + G.nodes[nm(ret)]['style'] = ('filled, dashed' if dashed else 'filled') G.nodes[nm(ret)]['prunable'] = optype in [LoadOps, MovementOps] # prune movementops and loadops diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 0ccf98e31a..2d511c0dc0 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Optional, Tuple, Union, List, Dict, Any, ClassVar, cast +from typing import Optional, Tuple, Union, List, Dict, Any, cast import sys, weakref, importlib, inspect, functools, pathlib from weakref import WeakValueDictionary from tinygrad.helpers import prod, getenv, DType, dtypes, LazyNumpyArray, flatten, ImageDType from tinygrad.shape.shapetracker import ShapeTracker, get_contraction -from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers +from tinygrad.ops import Compiled, Interpreted, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, get_lazyops, get_buffers, map_buffers from tinygrad.runtime.lib import RawConst, RawBuffer # lazy can recurse a lot @@ -22,10 +22,9 @@ class _Device: Device = _Device() # TODO: movement ops that only change shape are really nops. treat them as such -REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1 -MERGE_ELEMENTWISE_OPS, MERGE_ONE_REDUCE_INTO_ELEMENTWISE = OPT>=2, OPT>=2 +REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1 +MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2 # shuffle pad ops is fine now since we only push to merge binops PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3 -SHUFFLE_PAD_OPS = OPT>=4 # no longer makes wrong outputs since div isn't allowed, but still unadvisable # **** realize functions **** def _ast_reduceops(self:LazyBuffer) -> LazyOp: @@ -42,15 +41,16 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp: psrcs: List[Tuple[LazyBuffer, LazyBuffer]] = [(k,x) for k,x in zip(real_srcs.keys(), map(get_movementroot_contiguous, real_srcs.keys())) if x.optype == ReduceOps and x.realized is None and prod(k.shape) == prod(x.shape) and len(x.children) <= 1 and len(k.children) <= 1] intermediate_shape: Tuple[int, ...] = self.shape if len(psrcs) == 1 and MERGE_ONE_REDUCE_INTO_ELEMENTWISE: - if psrcs[0][1].optype == ReduceOps: - top = _ast_reduceops(psrcs[0][1]) - real_srcs[psrcs[0][0]] = top + psrc = psrcs[0] # NOTE: right now we can't handle multiple, as we'd have to check for loop + if psrc[1].optype == ReduceOps: + top = _ast_reduceops(psrc[1]) + real_srcs[psrc[0]] = top real_srcs.update({x:x for x in get_buffers(top)}) # the reduce op buffers are not modified # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs - if psrcs[0][0].shape != psrcs[0][1].shape: - intermediate_shape = psrcs[0][1].shape - assert psrcs[0][0].shape == self.shape, f"shape mismatch {psrcs[0][0].shape} != {self.shape}" + if psrc[0].shape != psrc[1].shape: + intermediate_shape = psrc[1].shape + assert psrc[0].shape == self.shape, f"shape mismatch {psrc[0].shape} != {self.shape}" # reshape all the late ops into the output shape # NOTE: these RESHAPEs will return self if they don't change the shape @@ -66,41 +66,48 @@ def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root. def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(root.op.src[0], allow_contiguous) if root.realized is None and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(x.op.src[0]) if x.realized is None and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x) -def replace_with_movement_op(y:Union[LazyOp, LazyBuffer], op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer: - if isinstance(y, LazyBuffer): return y.movement_op(op, arg) +def replace_with_movement_ops(y:Union[LazyOp, LazyBuffer], ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> LazyBuffer: + if isinstance(y, LazyBuffer): + for op, arg in ops: y = y.movement_op(op, arg) + return y assert y.op in BinaryOps or y.op in UnaryOps - return elementwise_op(y.op, *[replace_with_movement_op(z, op, arg) for z in y.src], arg=y.arg) # type: ignore + return elementwise_op(y.op, *[replace_with_movement_ops(z, ops) for z in y.src], arg=y.arg) # type: ignore + +lazycache: WeakValueDictionary[Tuple[str, DType, OpType, LazyOp], LazyBuffer] = WeakValueDictionary() +def create_lazybuffer(device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp, dtype:DType): + st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape)) + + # fromcpu aren't cached + if optype == LoadOps and op.op == LoadOps.FROMCPU: return LazyBuffer(device, st, optype, op, dtype) + + #print("create_lazybuffer", device, shape, optype, op, dtype) + + # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker + # get_weakop makes all the LazyBuffers in the op have a weakref + wop = (device, dtype, optype, get_weakop(op)) + + if wop not in lazycache: lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype) + else: ret = lazycache[wop] + return ret -def support_weakref(x): return x -@support_weakref # needed for mypyc, this prevents LazyBuffer from becoming a native class class LazyBuffer: __deletable__ = ('op',) - lazycache: ClassVar[WeakValueDictionary[Tuple[str, DType, OpType, LazyOp], LazyBuffer]] = WeakValueDictionary() - def __new__(cls, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp, dtype:DType): - # fromcpu aren't cached - if optype == LoadOps and op.op == LoadOps.FROMCPU: - return super().__new__(cls) - wop = (device, dtype, optype, get_weakop(op)) # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker - # NOTE: we need "ret" to prevent the new buffer from being immediately deleted - if wop not in LazyBuffer.lazycache: LazyBuffer.lazycache[wop] = ret = super().__new__(cls) - else: ret = LazyBuffer.lazycache[wop] - return ret - - def __init__(self, device:str, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp, dtype:DType): - if hasattr(self, 'device'): - return # cache hit, we return and don't reinit - self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape)) - self.shape, self.optype, self.dtype = self.st.shape, optype, dtype + def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType): + self.st = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker + self.device, self.shape, self.optype, self.dtype = device, self.st.shape, optype, dtype self.op: LazyOp = op self.realized: Optional[RawBuffer] = None - self.output_buffer: Optional[RawBuffer] = None - self.device, self.dbuffer = device, Device[device] + self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child? self.children: weakref.WeakSet[LazyBuffer] = weakref.WeakSet() # NOTE: op should be read only after construction of LazyBuffer for x in get_buffers(op): x.children.add(self) if not LAZY: self.realize() + # log phantom ops to the graph + from tinygrad.graph import log_op, GRAPH + if GRAPH >= 2: log_op(self, self.op, phantom=True) + def __repr__(self): return f"" def realize(self:LazyBuffer) -> LazyBuffer: @@ -155,11 +162,11 @@ class LazyBuffer: # NOTE: we have to make a copy of the numpy array here in case the user changes it. expose this? LazyNumpyArray doesn't have this problem @staticmethod def fromCPU(x:LazyNumpyArray, device) -> LazyBuffer: - return LazyBuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x.dtype)) + return create_lazybuffer(device, x.shape, LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), x.copy()), dtypes.from_np(x.dtype)) # create a constant with the shape and dtype of self def const_like(self, val) -> LazyBuffer: - return LazyBuffer(self.device, (1,), LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), LazyNumpyArray([val], (1,), self.dtype.np)), self.dtype) \ + return create_lazybuffer(self.device, (1,), LoadOps, LazyOp(LoadOps.FROMCPU, tuple(), LazyNumpyArray([val], (1,), self.dtype.np)), self.dtype) \ .movement_op(MovementOps.RESHAPE, (1,)*len(self.shape)).movement_op(MovementOps.EXPAND, self.shape) # NOTE: we also have to copy the numpy array on the way out...otherwise the underlying Tensor could be freed and use after free. improve this? @@ -173,12 +180,13 @@ class LazyBuffer: def binary_op(self:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, self, y) def contiguous(self:LazyBuffer) -> LazyBuffer: if self.realized is None and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one - return LazyBuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)), self.dtype) + return create_lazybuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)), self.dtype) def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer: if self.shape == tuple(new_shape): return self - return LazyBuffer(self.device, new_shape, ReduceOps, LazyOp(op, (self,), new_shape), self.dtype) + return create_lazybuffer(self.device, new_shape, ReduceOps, LazyOp(op, (self,), new_shape), self.dtype) + # shrink -> stride -> permute -> reshape -> pad -> expand def movement_op(self:LazyBuffer, op:MovementOps, arg:Tuple[Any, ...]) -> LazyBuffer: # very instant nop if op == MovementOps.RESHAPE and self.shape == arg: return self @@ -224,11 +232,11 @@ class LazyBuffer: .movement_op(MovementOps.RESHAPE, ShapeTracker(self.st).movement_op(op, arg).shape) # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead. NOTE: UnaryOps is never an OpType - if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and op != MovementOps.EXPAND and (op != MovementOps.PAD or (SHUFFLE_PAD_OPS and all(x.op != BinaryOps.DIV for x in get_lazyops(self.op)))): - return replace_with_movement_op(self.op, op, arg) + if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and (op in [MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE] or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and len(self.children) == 0: # and op != MovementOps.EXPAND and (op != MovementOps.PAD or (SHUFFLE_PAD_OPS and all(x.op != BinaryOps.DIV for x in get_lazyops(self.op)))): + return replace_with_movement_ops(self.op, [(op, arg)]) # create the buffer - ret = LazyBuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg), self.dtype) + ret = create_lazybuffer(self.device, ShapeTracker(self.st).movement_op(op, arg), MovementOps, LazyOp(op, (self,), arg), self.dtype) # if the ShapeTracker becomes contiguous, replace the whole thing with a reshape (or nothing if shapes match) # NOTE: if ret is in the cache, it can already be realized @@ -243,6 +251,26 @@ class LazyBuffer: def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer: out_device, out_shape, out_dtype = srcs[0].device, srcs[0].shape, max(x.dtype for x in srcs) if op != UnaryOps.CAST else cast(DType, arg) + # if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops + if SHUFFLE_MOVEMENT_OPS: + new_srcs = [] + did_replace = False + for x in srcs: + mops: List[Tuple[MovementOps, Tuple[Any, ...]]] = [] + bx = x + # backwalk all the movement ops. don't push PAD or EXPAND + while bx.realized is None and bx.optype == MovementOps and bx.op.op != MovementOps.EXPAND and (bx.op.op != MovementOps.PAD or SHUFFLE_PAD_OPS) and len(bx.children) <= 1: + assert isinstance(bx.op.op, MovementOps) + mops.append((bx.op.op, bx.op.arg)) + bx = bx.op.src[0] + # NOTE: can't push pads with a div + if bx.realized is None and bx.optype == BinaryOps and len(bx.children) <= 1 and len(mops) and (all(x[0] != MovementOps.PAD for x in mops) or all(x.op != BinaryOps.DIV for x in get_lazyops(bx.op))): + new_srcs.append(replace_with_movement_ops(bx.op, mops[::-1])) + did_replace = True + else: + new_srcs.append(x) + if did_replace: return elementwise_op(op, *new_srcs, arg=arg) + # push all contiguous to the end of BinaryOps. kernels 198 -> 196 if PUSH_CONTIGUOUS and any(x.realized is None and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs): new_srcs = [] @@ -254,8 +282,8 @@ def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer, arg:Optional new_srcs.append(x) return elementwise_op(op, *new_srcs, arg=arg).contiguous() - if MERGE_ELEMENTWISE_OPS or (MERGE_UNARY_OPS and len(set(srcs)) == 1): + if MERGE_ELEMENTWISE_OPS: # remove the buffers from any (childless) BinaryOps that feed into this srcs = tuple(x.op if x.optype == BinaryOps and len(x.children) == 0 and x.realized is None else x for x in srcs) # type: ignore - return LazyBuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs, arg), out_dtype) + return create_lazybuffer(out_device, out_shape, BinaryOps, LazyOp(op, srcs, arg), out_dtype) diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index 88e82f550f..dd68ac225d 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -21,7 +21,6 @@ class BatchNorm2d: y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1])) batch_var = (y*y).mean(axis=(0,2,3)) batch_invstd = batch_var.add(self.eps).pow(-0.5) - self.batch_invstd = None # NOTE: wow, this is done all throughout training in most PyTorch models if self.track_running_stats: @@ -29,11 +28,9 @@ class BatchNorm2d: self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var) self.num_batches_tracked += 1 else: - batch_mean, batch_var = self.running_mean, self.running_var - # NOTE: this can be precomputed for static inference. if you manually update running_var, you have to reset this - if not hasattr(self, "batch_invstd") or not self.batch_invstd: - self.batch_invstd = batch_var.add(self.eps).pow(-0.5) - batch_invstd = self.batch_invstd + batch_mean = self.running_mean + # NOTE: this can be precomputed for static inference. we expand it here so it fuses + batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt() return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 0e21e461c4..9f7853a3e1 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -49,7 +49,7 @@ class Interpreted: if not created_context and ast in context: return context[ast] srcs = [self.exec_ast(x, context=context) if isinstance(x, LazyOp) else self.from_lazybuffer(x) for x in ast.src] ret = self.buffer(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else [])))) - if DEBUG >= 4: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "") + if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape):30s} in({len(srcs)}):", list(set(x._buf.shape for x in srcs)), ast.arg if ast.arg is not None else "") if not created_context: context[ast] = ret if output is not None and output.output_buffer is not None: assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype @@ -65,7 +65,7 @@ class FlopCounter: return ret from tinygrad.shape.shapetracker import ShapeTracker shape_fxn_for_op: Dict[Op, Callable] = { - UnaryOps.CAST: lambda self,dtype: (self.shape, dtype, self.consume_flops() + prod(self.shape)), + UnaryOps.CAST: lambda self,dtype: (self.shape, dtype, self.consume_flops()), # cast uses no flops **{op:lambda self: (self.shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in UnaryOps if op != UnaryOps.CAST}, **{op:lambda self,y: (self.shape, max(self.dtype, y.dtype), self.consume_flops() + y.consume_flops() + prod(self.shape)) for op in BinaryOps}, **{op:lambda self,new_shape: (new_shape, self.dtype, self.consume_flops() + prod(self.shape)) for op in ReduceOps}, diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 7f698510f5..54e0c4073c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -263,6 +263,8 @@ class Tensor: # (padding_left, padding_right, padding_top, padding_bottom) def pad2d(self, padding:Union[List[int], Tuple[int, ...]]): return self.slice(((0,self.shape[0]), (0,self.shape[1]), (-padding[2],self.shape[2]+padding[3]), (-padding[0],self.shape[3]+padding[1]))) + @property + def T(self) -> Tensor: return self.transpose() def transpose(self, ax1=1, ax2=0) -> Tensor: order = list(range(len(self.shape))) order[ax1], order[ax2] = order[ax2], order[ax1] @@ -344,7 +346,7 @@ class Tensor: x = x.reshape(bs, groups, 1, cin, oy, ox, H, W).expand(bs, groups, rcout, cin, oy, ox, H, W).permute(0,1,2,4,5,3,6,7) # conv! broadcasted to (bs, groups, rcout, oy, ox, cin, H, W) - ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1)).reshape(bs, cout, oy, ox) + ret = (x * weight.reshape(1, groups, rcout, 1, 1, cin, H, W)).sum((-3, -2, -1), keepdim=True).reshape(bs, cout, oy, ox) return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1)) def dot(self, w:Tensor) -> Tensor: @@ -397,7 +399,7 @@ class Tensor: def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self - def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self + def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse else self.mul(1/x) def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x) def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x) @@ -443,12 +445,12 @@ class Tensor: def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self) def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor: - y = (self - self.mean(axis=axis, keepdim=True)) - return y.div((y*y).mean(axis=axis, keepdim=True).add(eps).sqrt()) + y = (self - self.mean(axis, keepdim=True)) + return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt()) def batchnorm(self, weight:Tensor, bias:Tensor, mean:Tensor, invstd:Tensor) -> Tensor: x = (self - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1]) - return x.mul(invstd.reshape(shape=[1, -1, 1, 1])) + bias.reshape(shape=[1, -1, 1, 1]) + return x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd) + bias.reshape(shape=[1, -1, 1, 1]) def dropout(self, p=0.5) -> Tensor: if not Tensor.training: return self