diff --git a/examples/handcode_opt.py b/examples/handcode_opt.py
index 8d2e2833cd..2938f7072c 100644
--- a/examples/handcode_opt.py
+++ b/examples/handcode_opt.py
@@ -109,7 +109,7 @@ if __name__ == "__main__":
     choices = []
     for lin, nm in lins:
       tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
-      ops = (prg:=lin.to_program()).op_estimate
+      ops = (prg:=lin.to_program()).estimates.ops
       gflops = sym_infer(ops, {k:k.min for k in lin.ast.variables()})*1e-9/tm
       choices.append((tm, gflops, lin, prg, nm))
 
diff --git a/test/test_arange.py b/test/test_arange.py
index bed8652777..d8a215faab 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -22,7 +22,7 @@ class TestArange(unittest.TestCase):
     #print(p.src)
     ExecItem(CompiledRunner(p), [tt.lazydata.buffer]).run()
     np.testing.assert_equal(tt.numpy(), np.arange(N))
-    return p.op_estimate
+    return p.estimates.ops
 
   def test_complexity(self, opts=None, limit=None):
     # add 1 to avoid divide by 0. arange is 0 flops now!
diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index 10b12dbc6c..87cef32f61 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -2,7 +2,7 @@ import unittest
 from tinygrad import Tensor
 from tinygrad.helpers import getenv, GlobalCounters
 from tinygrad.engine.schedule import create_schedule
-from tinygrad.engine.realize import lower_schedule_item
+from tinygrad.engine.realize import lower_schedule_item, ProgramSpec
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.ops import flops_mem, Ops, UOp
 from tinygrad.dtype import dtypes
@@ -13,7 +13,7 @@ from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 def get_stats(x:Tensor):
   si = create_schedule([x.lazydata])[-1]
   ei = lower_schedule_item(si)
-  return ei.prg.op_estimate, ei.prg.mem_estimate
+  return ei.prg.estimates.ops, ei.prg.estimates.mem
 
 class TestMemoryCount(unittest.TestCase):
   def test_add(self):
@@ -148,17 +148,17 @@ class TestStatsOptimized(unittest.TestCase):
     cls.ast_gemm = (Tensor.empty(N, N) @ Tensor.empty(N, N)).schedule()[-1].ast
     cls.ast_reduce = (Tensor.empty(N*N).sum()).schedule()[-1].ast
 
-  def check_gemm(self, p, extra_flops=0):
+  def check_gemm(self, p:ProgramSpec, extra_flops=0):
     #p.uops.print()
     #print(p.src)
-    print(p.name, p.op_estimate, p.mem_estimate, p.lds_estimate)
-    self.assertEqual(p.op_estimate, 2*N*N*N + extra_flops)  # N**3 mulaccs
-    self.assertEqual(p.mem_estimate, 3*N*N*4)  # 3 NxN mats with floats
+    print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
+    self.assertEqual(p.estimates.ops, 2*N*N*N + extra_flops)  # N**3 mulaccs
+    self.assertEqual(p.estimates.mem, 3*N*N*4)  # 3 NxN mats with floats
 
   def test_gemm(self):
     p = Kernel(self.ast_gemm).to_program()
     self.check_gemm(p)
-    self.assertEqual(p.lds_estimate, 2*N*N*N*4 + 4*N*N)
+    self.assertEqual(p.estimates.lds, 2*N*N*N*4 + 4*N*N)
 
   # this is a good lesson about why UPCASTing is a good idea
@@ -167,7 +167,7 @@ class TestStatsOptimized(unittest.TestCase):
     k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
     p = k.to_program()
     self.check_gemm(p)
-    self.assertEqual(p.lds_estimate, N*N*N*4 + N*N*N*4//4 + 4*N*N)
+    self.assertEqual(p.estimates.lds, N*N*N*4 + N*N*N*4//4 + 4*N*N)
 
   def test_gemm_upcasted(self):
     k = Kernel(self.ast_gemm)
@@ -176,7 +176,7 @@ class TestStatsOptimized(unittest.TestCase):
     k.apply_opt(Opt(OptOps.UPCAST, 0, 4))
     k.apply_opt(Opt(OptOps.UNROLL, 0, 4))
     p = k.to_program()
     self.check_gemm(p)
-    self.assertEqual(p.lds_estimate, 2*N*N*N*4//4 + 4*N*N)
+    self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)
 
   def test_gemm_upcasted_locals(self):
     k = Kernel(self.ast_gemm)
@@ -189,7 +189,7 @@ class TestStatsOptimized(unittest.TestCase):
       raise unittest.SkipTest("no locals")
     p = k.to_program()
     self.check_gemm(p)
-    self.assertEqual(p.lds_estimate, 2*N*N*N*4//4 + 4*N*N)
+    self.assertEqual(p.estimates.lds, 2*N*N*N*4//4 + 4*N*N)
 
   def test_gemm_group(self):
     k = Kernel(self.ast_gemm)
@@ -201,14 +201,14 @@ class TestStatsOptimized(unittest.TestCase):
     p = k.to_program()
     # NOTE: these are sort of wrong. they aren't honoring the IF statement
     self.check_gemm(p, extra_flops=SZ*4)
-    self.assertEqual(p.lds_estimate, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
+    self.assertEqual(p.estimates.lds, 2*N*N*N*4 + SZ*4 + (SZ*4 + 4*N*N)*4)
 
   def test_reduce(self):
     k = Kernel(self.ast_reduce)
     p = k.to_program()
-    print(p.name, p.op_estimate, p.mem_estimate, p.lds_estimate)
-    self.assertEqual(p.op_estimate, N*N)
-    self.assertEqual(p.mem_estimate, N*N*4 + 4)
+    print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
+    self.assertEqual(p.estimates.ops, N*N)
+    self.assertEqual(p.estimates.mem, N*N*4 + 4)
 
   def test_reduce_group(self):
     k = Kernel(self.ast_reduce)
@@ -218,7 +218,7 @@ class TestStatsOptimized(unittest.TestCase):
       raise unittest.SkipTest("no locals")
     p = k.to_program()
     # NOTE: these are wrong, they don't respect the if statement
-    print(p.name, p.op_estimate, p.mem_estimate, p.lds_estimate)
+    print(p.name, p.estimates.ops, p.estimates.mem, p.estimates.lds)
 
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 378393910d..64dc070783 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -4,9 +4,9 @@ from tinygrad.tensor import Tensor
 from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
-from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
+from tinygrad.ops import UOp, Variable, sym_infer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner
+from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
 from tinygrad.engine.memory import _internal_memory_planner
 from tinygrad.nn.state import get_parameters
 from dataclasses import dataclass
@@ -75,10 +75,6 @@ class GraphRunner(Runner):
     self.launch_dims_replace:Dict[int, Tuple[Optional[int], Optional[int]]] = {}
     self.launch_dims_base:Dict[int, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {}
 
-    op_estimate: sint = 0
-    mem_estimate: sint = 0
-    lds_estimate: sint = 0
-
     def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
 
     self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
@@ -86,10 +82,9 @@ class GraphRunner(Runner):
       [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
     def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
 
+    estimates = Estimates()
     for j,ji in enumerate(jit_cache):
-      op_estimate += ji.prg.op_estimate
-      mem_estimate += ji.prg.mem_estimate
-      lds_estimate += ji.prg.lds_estimate
+      estimates += ji.prg.estimates
       if isinstance(ji.prg, CompiledRunner):
         if ji.prg.p.vars: self.var_vals_replace[j] = [self.vars.index(v) for v in ji.prg.p.vars]
@@ -103,8 +98,7 @@ class GraphRunner(Runner):
     self.w_dependency_map: Dict[int, Any] = {}
     self.r_dependency_map: Dict[int, List[Any]] = collections.defaultdict(list)
 
-    super().__init__(colored(f"", "cyan"), jit_cache[0].prg.device.split(":")[0],
-                     ssimplify(op_estimate), ssimplify(mem_estimate), ssimplify(lds_estimate))
+    super().__init__(colored(f"", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())
 
   def updated_vars(self, var_vals: Dict[Variable, int]):
     vals = [var_vals[v] for v in self.vars]
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 6144714d55..07deaf3ed2 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -2,9 +2,9 @@ from typing import List, Dict, Optional, cast, Generator, Tuple
 import time, pprint
 from dataclasses import dataclass, replace
 from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
-from tinygrad.ops import Ops, UOp, Variable, sym_infer, sint
+from tinygrad.ops import Ops, UOp, Variable, sym_infer
 from tinygrad.device import Device, Buffer
-from tinygrad.renderer import Renderer, ProgramSpec
+from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import ScheduleItem
@@ -28,9 +28,8 @@ def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
 # **************** Runners ****************
 
 class Runner:
-  def __init__(self, display_name:str, device:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None):
-    self.first_run, self.display_name, self.device, self.op_estimate, self.mem_estimate, self.lds_estimate = \
-      True, display_name, device, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate
+  def __init__(self, display_name:str, device:str, estimates=Estimates()):
+    self.first_run, self.display_name, self.device, self.estimates = True, display_name, device, estimates
   @property
   def dev(self): return Device[self.device]
   def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
@@ -45,7 +44,7 @@ class CompiledRunner(Runner):
     self.lib:bytes = precompiled if precompiled is not None else Device[p.device].compiler.compile_cached(p.src)
     if DEBUG >= 6: Device[p.device].compiler.disassemble(self.lib)
     self._prg = Device[p.device].runtime(p.function_name, self.lib)
-    super().__init__(p.name, p.device, p.op_estimate, p.mem_estimate, p.lds_estimate)
+    super().__init__(p.name, p.device, p.estimates)
 
   def __reduce__(self): return self.__class__, (self.p, self.lib)
@@ -79,7 +78,7 @@ class BufferCopy(Runner):
   def __init__(self, total_sz, dest_device, src_device):
     if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
     else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
-    super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
+    super().__init__(colored(name, "yellow"), dest_device, Estimates(mem=total_sz))
   def copy(self, dest, src):
     disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
       getattr(src.allocator.dev, 'fd', None) is not None
@@ -129,11 +128,11 @@ class ExecItem:
     et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
     if do_update_stats:
       GlobalCounters.kernel_count += 1
-      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.op_estimate, var_vals))
-      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.mem_estimate, var_vals))
+      GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.estimates.ops, var_vals))
+      GlobalCounters.global_mem += (mem_est:=sym_infer(self.prg.estimates.mem, var_vals))
       if et is not None: GlobalCounters.time_sum_s += et
       if DEBUG >= 2:
-        lds_est = sym_infer(self.prg.lds_estimate, var_vals)
+        lds_est = sym_infer(self.prg.estimates.lds, var_vals)
         mem_est = min(mem_est, lds_est)  # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
         ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
         print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py
index 42c488d813..cf5805c682 100644
--- a/tinygrad/engine/search.py
+++ b/tinygrad/engine/search.py
@@ -152,7 +152,7 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
       p, lib, compile_et = proc
       if lib in seen_libs: continue
       # filter out kernels that use 1000x more compute than the smallest
-      least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops)
+      least_compute_ops = min(this_compute_ops:=sym_infer(p.estimates.ops, var_vals), least_compute_ops)
       if least_compute_ops*1000 < this_compute_ops: continue
       seen_libs.add(lib)
       try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 429458e9d3..8242cd1fbc 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -687,18 +687,12 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
     if u.op is Ops.RANGE:
       mult_stack.append(mults)
       mults *= (u.src[1] - u.src[0]).ssimplify()
-    elif u.op is Ops.ENDRANGE:
-      mults = mult_stack.pop(-1)
-    elif u.op is Ops.SPECIAL:
-      mults *= u.arg[1]  # NOTE: we don't push to the mult_stack here, you can't end these
-    elif u.op is Ops.LOAD:
-      mem += u.dtype.itemsize * mults
-    elif u.op is Ops.STORE:
-      mem += u.src[1].dtype.itemsize * mults
-    elif u.op in GroupOp.ALU and u not in dont_count:
-      flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
-    elif u.op is Ops.WMMA and u not in dont_count:
-      flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
+    elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
+    elif u.op is Ops.SPECIAL: mults *= u.arg[1]  # NOTE: we don't push to the mult_stack here, you can't end these
+    elif u.op is Ops.LOAD: mem += u.dtype.itemsize * mults
+    elif u.op is Ops.STORE: mem += u.src[1].dtype.itemsize * mults
+    elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
+    elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
   return flops, mem
 
 # ***** pattern matcher *****
diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py
index d2cbdf4a95..db2f3d69ee 100644
--- a/tinygrad/renderer/__init__.py
+++ b/tinygrad/renderer/__init__.py
@@ -1,8 +1,9 @@
+from __future__ import annotations
 from typing import Optional, List, Tuple, Dict, Callable, Any
 import functools
 from dataclasses import dataclass, field
 from tinygrad.helpers import to_function_name, dedup, prod
-from tinygrad.ops import Ops, UOp, flops_mem, sym_infer, sint, Variable
+from tinygrad.ops import Ops, UOp, flops_mem, sym_infer, sint, Variable, ssimplify
 from tinygrad.dtype import DType
 
 @dataclass(frozen=True)
@@ -22,6 +23,17 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
   opts_seq: Tuple[str,str] = ("UP","LC")  # upcast input, local the thread pattern
   def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
 
+@dataclass(frozen=True)
+class Estimates:
+  # number of FLOPS used in the Kernel
+  ops:sint = 0
+  # bytes accessed in loads and stores
+  lds:sint = 0
+  # total bytes accessed, counting only once for bytes that are accessed multiple times
+  mem:sint = 0
+  def __add__(self, o:Estimates): return Estimates(self.ops + o.ops, self.lds + o.lds, self.mem + o.mem)
+  def simplify(self): return Estimates(ssimplify(self.ops), ssimplify(self.lds), ssimplify(self.mem))
+
 @dataclass
 class ProgramSpec:
   name:str
@@ -55,12 +67,8 @@ class ProgramSpec:
       self.outs = sorted(dedup(self.outs))
       self._ran_post_init = True
 
-  @property
-  def op_estimate(self) -> sint: return self._ops_lds[0]
-  @property
-  def lds_estimate(self) -> sint: return self._ops_lds[1]
   @functools.cached_property
-  def _ops_lds(self) -> Tuple[sint, sint]: return (0,0) if self.uops is None else flops_mem(self.uops, ignore_indexing=True)
+  def estimates(self) -> Estimates: return Estimates(*((0,0) if self.uops is None else flops_mem(self.uops, ignore_indexing=True)), self.mem_estimate)
 
   @functools.cached_property
   def function_name(self) -> str: return to_function_name(self.name)
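
As a quick illustration of the aggregation pattern this patch introduces: a frozen Estimates-style value object with ops/lds/mem fields can be summed across kernels, which is what lets GraphRunner replace its three running totals with a single `estimates += ji.prg.estimates`. The sketch below is standalone and not the library code itself; plain ints stand in for tinygrad's symbolic `sint`, and the example numbers are made up.

from dataclasses import dataclass

@dataclass(frozen=True)
class Estimates:
  ops: int = 0  # FLOPs executed by the kernel
  lds: int = 0  # bytes moved by loads and stores
  mem: int = 0  # unique bytes accessed, counted once per buffer
  def __add__(self, o:"Estimates") -> "Estimates":
    return Estimates(self.ops + o.ops, self.lds + o.lds, self.mem + o.mem)

# fold per-kernel estimates together, the way GraphRunner now sums ji.prg.estimates
N = 16
gemm = Estimates(ops=2*N*N*N, lds=3*N*N*4, mem=3*N*N*4)
reduce_ = Estimates(ops=N*N, lds=N*N*4 + 4, mem=N*N*4 + 4)
total = sum([gemm, reduce_], Estimates())
print(total.ops, total.lds, total.mem)  # 8448 4100 4100

Keeping the three counters in one value object is also what allows Runner to carry a single `estimates` attribute and GraphRunner to call `estimates.simplify()` once, instead of simplifying three separate symbolic expressions.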