diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index e715bec423..4e84f6f237 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import importlib, inspect, functools, pathlib, time
+import importlib, inspect, functools, pathlib, time, re
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
@@ -134,6 +134,19 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
   def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
   return run_ast(ast)
 
+# **************** GlobalCounters stats ****************
+
+def update_stats(name, op_estimate, mem_estimate, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra=None):
+  if var_vals is None: var_vals = {}
+  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
+  if DEBUG >= 2:
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G  mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
+  GlobalCounters.kernel_count += num_kernels
+  GlobalCounters.global_ops += op_estimate
+  GlobalCounters.global_mem += mem_estimate
+  if et is not None: GlobalCounters.time_sum_s += et
+
 # **************** batch executor ****************
 
 @dataclass(frozen=True)
@@ -161,18 +174,6 @@ class BatchExecutor:
     for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
     self.clear_jit_inputs()
 
-  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
-    # TODO: this is mostly copied from ASTRunner
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G  mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += len(self.jit_cache)
-    GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
-    GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
-    if et is not None: GlobalCounters.time_sum_s += et
-
   def clear_jit_inputs(self):
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
 
@@ -193,17 +194,6 @@ class ASTRunner:
       CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
     return et
 
-  def update_stats(self, name, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, lra, jit):
-    if var_vals is None: var_vals = {}
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '')):18s} {str(lra.get('local_size', '')):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G  mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += 1
-    GlobalCounters.global_ops += op_estimate
-    GlobalCounters.global_mem += mem_estimate
-
   def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
     raise NotImplementedError("override this")
 
@@ -218,7 +208,7 @@ class InterpretedASTRunner(ASTRunner):
     st = time.perf_counter()
     ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
     et = time.perf_counter() - st
-    self.update_stats(f"<interpreted {ret.size}>", var_vals, et, len(rawbufs), {}, jit)
+    update_stats(f"<interpreted {ret.size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
     if rawbufs[0] is not None:
       assert rawbufs[0].dtype == ret.dtype
       rawbufs[0].size = ret.size  # NOTE: for symbolic this can change
@@ -226,7 +216,6 @@ class InterpretedASTRunner(ASTRunner):
     else: rawbufs[0] = ret
     return et
 
-from tinygrad.runtime.interpreted import interpret_ast
 class Interpreted:
   def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable]=None):
     self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
@@ -236,11 +225,50 @@ class Interpreted:
     self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {}
 
   def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
-    if ast not in self.method_cache: self.method_cache[ast] = InterpretedASTRunner(ast, interpret_ast(self.fxn_for_op, self.from_underlying, ast))
+    if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, self.from_underlying, ast)
     rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
     self.method_cache[ast].exec(rawbufs, var_vals)
     output.realized = rawbufs[0]
 
+def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> InterpretedASTRunner:
+  if DEBUG >= 3:
+    from tinygrad.graph import print_tree
+    print_tree(ast)
+  tglob: Dict[str, Any] = {"Variable": Variable}
+  lines: List[str] = []
+
+  @functools.lru_cache(None)
+  def gstr(x:Any, nm=None) -> str:
+    if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
+      str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
+      # TODO: (Variable - Variable) might create NumNode. can we remove it?
+      return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
+    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
+    tglob[ret] = x
+    return ret
+
+  @functools.lru_cache(None)
+  def _interpret_ast(ast:LazyOp) -> str:
+    if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
+      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
+
+    if ast.op in BufferOps:
+      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
+      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
+    else:
+      inp = [_interpret_ast(src) for src in ast.src]
+      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
+
+    ret = f"a{len(lines)}"
+    lines.append(f"  {ret} = {tmp}")
+    return ret
+
+  ret = _interpret_ast(ast)
+  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f"  return {ret}"])
+  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
+  exec(compile(src, "<ast>", "exec"), tglob)  # pylint: disable=exec-used
+  return InterpretedASTRunner(ast, tglob['run'])
+
 # **************** for Compiled Buffers ****************
 
 class CompiledASTRunner(ASTRunner):
@@ -272,8 +300,8 @@ class CompiledASTRunner(ASTRunner):
     lra = self.runtime_args.copy()
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
-    if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
-    self.update_stats(self.display_name if self.display_name is not None else self.name, var_vals, et, len(rawbufs), lra, jit)
+    et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2)
+    update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
    return et
 
 class Compiled:
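Reviewer note on the ops.py changes: the two near-duplicate update_stats methods on BatchExecutor and ASTRunner collapse into one module-level helper, with num_kernels and lra absorbing the differences between the batched and single-kernel call sites. A minimal sketch of how the two old call patterns map onto the shared helper; it assumes this patch is applied, and the kernel name, estimates, and launch sizes are invented stand-ins, not values from a real run:

# sketch only: exercising the consolidated helper with made-up numbers
from tinygrad.ops import update_stats

# single compiled kernel: magenta under JIT, launch args passed via lra
update_stats("r_256_16", op_estimate=2**20, mem_estimate=2**12, var_vals={},
             et=1e-4, buf_count=3, jit=True,
             lra={'global_size': [256, 1, 1], 'local_size': [16, 1, 1]})

# whole jitted batch: CYAN, bumps kernel_count by every cached kernel at once
update_stats("<batched 8>", op_estimate=2**24, mem_estimate=2**16, var_vals={},
             et=2e-3, buf_count=5, jit=True, num_kernels=8)

With DEBUG < 2 the helper only updates the GlobalCounters totals; at DEBUG >= 2 it also prints the per-kernel stat line.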
diff --git a/tinygrad/runtime/interpreted.py b/tinygrad/runtime/interpreted.py
deleted file mode 100644
index 06fda8e734..0000000000
--- a/tinygrad/runtime/interpreted.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from typing import Callable, Optional, Dict, List, Any
-import functools, re
-from tinygrad.helpers import DEBUG
-from tinygrad.ops import LazyOp, TernaryOps, ReduceOps, BinaryOps, BufferOps, Op
-from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.lib import RawBuffer
-
-def interpret_ast(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> Callable[[List[RawBuffer], Dict[Variable, int]], RawBuffer]:
-  if DEBUG >= 3:
-    from tinygrad.graph import print_tree
-    print_tree(ast)
-  tglob: Dict[str, Any] = {"Variable": Variable}
-  lines: List[str] = []
-
-  @functools.lru_cache(None)
-  def gstr(x:Any, nm=None) -> str:
-    if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
-      str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
-      # TODO: (Variable - Variable) might create NumNode. can we remove it?
-      return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
-    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
-    tglob[ret] = x
-    return ret
-
-  @functools.lru_cache(None)
-  def _interpret_ast(ast:LazyOp) -> str:
-    if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
-      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
-
-    if ast.op in BufferOps:
-      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
-      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
-    else:
-      inp = [_interpret_ast(src) for src in ast.src]
-      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
-
-    ret = f"a{len(lines)}"
-    lines.append(f"  {ret} = {tmp}")
-    return ret
-
-  ret = _interpret_ast(ast)
-  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f"  return {ret}"])
-  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
-  exec(compile(src, "<ast>", "exec"), tglob)  # pylint: disable=exec-used
-  return tglob['run']
\ No newline at end of file
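Reviewer note on the deletion: tinygrad/runtime/interpreted.py is the same code generator that now lives in ops.py as get_interpreted_fxn; the only behavioral change is that it returns an InterpretedASTRunner instead of a bare callable, which lets ops.py drop its import from tinygrad.runtime. The technique is easy to miss inside a diff: walk the LazyOp tree once, emit one `a{n} = ...` assignment per node, then compile and exec the generated run function. A self-contained miniature of that trick (not tinygrad code; the op table and tuple-based tree encoding are invented for the example):

# miniature of the interpreted codegen: one assignment per tree node,
# then compile+exec the generated source and hand back `run`
import operator
from typing import Any, Callable, Dict, List, Tuple, Union

Tree = Union[int, Tuple]  # int leaf = input index; tuple = (op_name, *children)
ops: Dict[str, Callable] = {"ADD": operator.add, "MUL": operator.mul}

def compile_tree(tree: Tree) -> Callable:
  tglob: Dict[str, Any] = dict(ops)
  lines: List[str] = []
  def walk(node: Tree) -> str:
    if isinstance(node, int): return f"inputs[{node}]"
    op, *src = node
    args = [walk(s) for s in src]  # children first, so a0, a1, ... appear in order
    ret = f"a{len(lines)}"
    lines.append(f"  {ret} = {op}({', '.join(args)})")
    return ret
  ret = walk(tree)
  src = "\n".join(["def run(inputs):"] + lines + [f"  return {ret}"])
  exec(compile(src, "<sketch>", "exec"), tglob)  # same exec trick as interpret_ast
  return tglob["run"]

run = compile_tree(("ADD", ("MUL", 0, 1), 2))
assert run([2, 3, 4]) == 10  # 2*3 + 4

For a real AST the generated source looks similar, e.g. a0 = BufferOps_MEM(inputs[0]) lines ending in a from_underlying(...) return, and DEBUG >= 4 prints it.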
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index bca9972ba2..4e07079c54 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -4,7 +4,7 @@ import Metal, Cocoa, libdispatch
 from typing import List, Any, Tuple, Dict, Union, Set, cast
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
-from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner
+from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner, update_stats
 from tinygrad.renderer.metal import MetalRenderer
 from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
 from tinygrad.shape.symbolic import Variable, Node
@@ -149,7 +149,7 @@ class MetalBatchExecutor(BatchExecutor):
     else:
       METAL.mtl_buffers_in_flight.append(command_buffer)
       et = None
-    super().update_stats(var_vals, et)
+    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=True, num_kernels=len(self.jit_cache))
     return et
 
 MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if METAL.supports_icb else BatchExecutor)
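Reviewer note on the Metal change: MetalBatchExecutor drops the inherited super().update_stats for the shared helper, passing num_kernels=len(self.jit_cache) so the whole jitted batch is counted at once (and, at DEBUG >= 2, printed as one bright line; tinygrad's colored() renders upper-case color names like 'CYAN' as the bright variant). A hypothetical smoke test that the shared counters still move after this refactor, assuming the patch is applied; the tensor sizes are arbitrary, and running with DEBUG=2 also shows the printed stat lines:

# hypothetical smoke test: GlobalCounters is fed by update_stats on every run
from tinygrad.helpers import GlobalCounters
from tinygrad.tensor import Tensor

before = GlobalCounters.kernel_count
(Tensor.rand(16, 16) @ Tensor.rand(16, 16)).realize()  # realizes a few kernels
print(f"kernels counted: {GlobalCounters.kernel_count - before}")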