JIT cleanups (#2317)

* cleanup cleanup

* dedup update_stats
George Hotz authored on 2023-11-15 13:34:52 -08:00 (committed by GitHub)
parent b64738e1d6, commit 628365eab6
3 changed files with 59 additions and 76 deletions
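This commit collapses three near-duplicate stats implementations (ASTRunner.update_stats, BatchExecutor.update_stats, and the inline time accounting in CompiledASTRunner) into one module-level update_stats helper in tinygrad/ops.py. For orientation before the diff, these are the three call sites after the change, taken verbatim from the hunks below:

    # interpreted kernel (InterpretedASTRunner.__call__)
    update_stats(f"<interpreted {ret.size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)

    # compiled kernel (CompiledASTRunner.__call__); lra carries global/local size for the debug printout
    update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)

    # batched JIT execution (MetalBatchExecutor); num_kernels bumps GlobalCounters.kernel_count by the whole batch
    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=True, num_kernels=len(self.jit_cache))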

--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import importlib, inspect, functools, pathlib, time
+import importlib, inspect, functools, pathlib, time, re
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
@@ -134,6 +134,19 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
   def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
   return run_ast(ast)
 
+# **************** GlobalCounters stats ****************
+
+def update_stats(name, op_estimate, mem_estimate, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra=None):
+  if var_vals is None: var_vals = {}
+  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
+  if DEBUG >= 2:
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
+  GlobalCounters.kernel_count += num_kernels
+  GlobalCounters.global_ops += op_estimate
+  GlobalCounters.global_mem += mem_estimate
+  if et is not None: GlobalCounters.time_sum_s += et
+
 # **************** batch executor ****************
 @dataclass(frozen=True)
@@ -161,18 +174,6 @@ class BatchExecutor:
     for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
     self.clear_jit_inputs()
 
-  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
-    # TODO: this is mostly copied from ASTRunner
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += len(self.jit_cache)
-    GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
-    GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
-    if et is not None: GlobalCounters.time_sum_s += et
-
   def clear_jit_inputs(self):
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
@@ -193,17 +194,6 @@ class ASTRunner:
     CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
     return et
 
-  def update_stats(self, name, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, lra, jit):
-    if var_vals is None: var_vals = {}
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '')):18s} {str(lra.get('local_size', '')):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += 1
-    GlobalCounters.global_ops += op_estimate
-    GlobalCounters.global_mem += mem_estimate
-
   def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
     raise NotImplementedError("override this")
@@ -218,7 +208,7 @@ class InterpretedASTRunner(ASTRunner):
     st = time.perf_counter()
     ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
     et = time.perf_counter() - st
-    self.update_stats(f"<interpreted {ret.size}>", var_vals, et, len(rawbufs), {}, jit)
+    update_stats(f"<interpreted {ret.size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
     if rawbufs[0] is not None:
       assert rawbufs[0].dtype == ret.dtype
       rawbufs[0].size = ret.size  # NOTE: for symbolic this can change
@@ -226,7 +216,6 @@ class InterpretedASTRunner(ASTRunner):
     else: rawbufs[0] = ret
     return et
 
-from tinygrad.runtime.interpreted import interpret_ast
 class Interpreted:
   def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable]=None):
     self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
@@ -236,11 +225,50 @@ class Interpreted:
     self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {}
 
   def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
-    if ast not in self.method_cache: self.method_cache[ast] = InterpretedASTRunner(ast, interpret_ast(self.fxn_for_op, self.from_underlying, ast))
+    if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, self.from_underlying, ast)
     rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
     self.method_cache[ast].exec(rawbufs, var_vals)
     output.realized = rawbufs[0]
 
+def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> InterpretedASTRunner:
+  if DEBUG >= 3:
+    from tinygrad.graph import print_tree
+    print_tree(ast)
+  tglob: Dict[str, Any] = {"Variable": Variable}
+  lines: List[str] = []
+
+  @functools.lru_cache(None)
+  def gstr(x:Any, nm=None) -> str:
+    if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
+      str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
+      # TODO: (Variable - Variable) might create NumNode. can we remove it?
+      return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
+    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
+    tglob[ret] = x
+    return ret
+
+  @functools.lru_cache(None)
+  def _interpret_ast(ast:LazyOp) -> str:
+    if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
+      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
+    if ast.op in BufferOps:
+      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
+      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
+    else:
+      inp = [_interpret_ast(src) for src in ast.src]
+      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
+    ret = f"a{len(lines)}"
+    lines.append(f"  {ret} = {tmp}")
+    return ret
+
+  ret = _interpret_ast(ast)
+  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f"  return {ret}"])
+  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
+  exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
+  return InterpretedASTRunner(ast, tglob['run'])
+
 # **************** for Compiled Buffers ****************
 
 class CompiledASTRunner(ASTRunner):
@@ -272,8 +300,8 @@ class CompiledASTRunner(ASTRunner):
     lra = self.runtime_args.copy()
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
-    if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
-    self.update_stats(self.display_name if self.display_name is not None else self.name, var_vals, et, len(rawbufs), lra, jit)
+    et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2)
+    update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
     return et
 
 class Compiled:
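For intuition about what the relocated interpreter codegen produces: get_interpreted_fxn walks the LazyOp tree once, emits one assignment per node, binds the op callables into the generated module's globals under names derived from the op (str(BufferOps.MEM) becomes BufferOps_MEM), and stashes non-trivial arguments under generated mNNNN names. A hypothetical AST that loads input 0, reshapes it, and applies UnaryOps.EXP2 (illustrative only, not from this commit) would compile to roughly:

    def run(inputs, var_vals):
      # the BufferOps.MEM load, wrapped in the movement ops from ast.arg.st.to_movement_ops();
      # m0003 is the reshape argument captured in tglob
      a0 = MovementOps_RESHAPE(BufferOps_MEM(inputs[0]), m0003)
      a1 = UnaryOps_EXP2(a0)
      return from_underlying(a1)

exec'ing this source with tglob as its globals and pulling out tglob['run'] yields the fxn that InterpretedASTRunner wraps.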

--- a/tinygrad/runtime/interpreted.py
+++ /dev/null

@@ -1,45 +0,0 @@
-from typing import Callable, Optional, Dict, List, Any
-import functools, re
-from tinygrad.helpers import DEBUG
-from tinygrad.ops import LazyOp, TernaryOps, ReduceOps, BinaryOps, BufferOps, Op
-from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.lib import RawBuffer
-
-def interpret_ast(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> Callable[[List[RawBuffer], Dict[Variable, int]], RawBuffer]:
-  if DEBUG >= 3:
-    from tinygrad.graph import print_tree
-    print_tree(ast)
-  tglob: Dict[str, Any] = {"Variable": Variable}
-  lines: List[str] = []
-
-  @functools.lru_cache(None)
-  def gstr(x:Any, nm=None) -> str:
-    if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
-      str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
-      # TODO: (Variable - Variable) might create NumNode. can we remove it?
-      return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
-    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
-    tglob[ret] = x
-    return ret
-
-  @functools.lru_cache(None)
-  def _interpret_ast(ast:LazyOp) -> str:
-    if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
-      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
-    if ast.op in BufferOps:
-      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
-      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
-    else:
-      inp = [_interpret_ast(src) for src in ast.src]
-      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
-    ret = f"a{len(lines)}"
-    lines.append(f"  {ret} = {tmp}")
-    return ret
-
-  ret = _interpret_ast(ast)
-  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f"  return {ret}"])
-  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
-  exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
-  return tglob['run']

--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py

@@ -4,7 +4,7 @@ import Metal, Cocoa, libdispatch
 from typing import List, Any, Tuple, Dict, Union, Set, cast
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
-from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner
+from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner, update_stats
 from tinygrad.renderer.metal import MetalRenderer
 from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
 from tinygrad.shape.symbolic import Variable, Node
@@ -149,7 +149,7 @@ class MetalBatchExecutor(BatchExecutor):
     else:
       METAL.mtl_buffers_in_flight.append(command_buffer)
       et = None
-    super().update_stats(var_vals, et)
+    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=True, num_kernels=len(self.jit_cache))
     return et
 
 MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if METAL.supports_icb else BatchExecutor)