diff --git a/tinygrad/ops.py b/tinygrad/ops.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import importlib, inspect, functools, pathlib, time
+import importlib, inspect, functools, pathlib, time, re
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
@@ -134,6 +134,19 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
   def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
   return run_ast(ast)
 
+# **************** GlobalCounters stats ****************
+
+def update_stats(name, op_estimate, mem_estimate, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra=None):
+  if var_vals is None: var_vals = {}
+  op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
+  if DEBUG >= 2:
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+          (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
+  GlobalCounters.kernel_count += num_kernels
+  GlobalCounters.global_ops += op_estimate
+  GlobalCounters.global_mem += mem_estimate
+  if et is not None: GlobalCounters.time_sum_s += et
+
 # **************** batch executor ****************
 
 @dataclass(frozen=True)
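
Note: every runner now funnels its bookkeeping through this one module-level helper; op_estimate and mem_estimate may be symbolic and are resolved against var_vals by sym_infer before the counters are bumped. A minimal sketch of the intended call pattern (the kernel name, sizes, estimates and timing are invented for illustration; the parameter names come from the function above):

# hypothetical call site, mirroring what a compiled runner does after launching one kernel
from tinygrad.helpers import GlobalCounters
from tinygrad.ops import update_stats
from tinygrad.shape.symbolic import Variable

i = Variable("i", 1, 512)                                    # symbolic batch dimension
op_est, mem_est = i*1000, i*64                               # made-up per-launch estimates (ops / bytes)
update_stats("E_example_4", op_est, mem_est, {i: 128},       # var_vals binds i=128 at run time
             et=1.5e-4,                                      # elapsed seconds measured by the runtime
             buf_count=3, jit=False,
             lra={"global_size": [128,1,1], "local_size": [32,1,1]})
print(GlobalCounters.global_ops, GlobalCounters.time_sum_s)  # counters now include this launch
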
@@ -161,18 +174,6 @@ class BatchExecutor:
     for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True)
     self.clear_jit_inputs()
 
-  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
-    # TODO: this is mostly copied from ASTRunner
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += len(self.jit_cache)
-    GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
-    GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
-    if et is not None: GlobalCounters.time_sum_s += et
-
   def clear_jit_inputs(self):
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
 
@@ -193,17 +194,6 @@ class ASTRunner:
     CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
     return et
 
-  def update_stats(self, name, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, lra, jit):
-    if var_vals is None: var_vals = {}
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '')):18s} {str(lra.get('local_size', '')):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += 1
-    GlobalCounters.global_ops += op_estimate
-    GlobalCounters.global_mem += mem_estimate
-
   def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
     raise NotImplementedError("override this")
 
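
Note: with ASTRunner.update_stats gone, the base class only keeps the __call__ contract above; concrete runners time their own work and report through the shared helper. A hypothetical minimal subclass, just to show the shape of that contract (not part of the diff; assumes ASTRunner, update_stats and time are in scope as in this module):

class NoopASTRunner(ASTRunner):
  # illustrative only: stores its own name/estimates and does no real work in __call__
  def __init__(self, name): self.name, self.op_estimate, self.mem_estimate = name, 0, 0
  def __call__(self, rawbufs, var_vals=None, jit=False, force_wait=False):
    st = time.perf_counter()
    # a real runner would launch a kernel or interpret an AST here
    et = time.perf_counter() - st
    update_stats(self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
    return et
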
@@ -218,7 +208,7 @@ class InterpretedASTRunner(ASTRunner):
     st = time.perf_counter()
     ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
     et = time.perf_counter() - st
-    self.update_stats(f"<interpreted {ret.size}>", var_vals, et, len(rawbufs), {}, jit)
+    update_stats(f"<interpreted {ret.size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit)
     if rawbufs[0] is not None:
       assert rawbufs[0].dtype == ret.dtype
       rawbufs[0].size = ret.size # NOTE: for symbolic this can change
@@ -226,7 +216,6 @@ class InterpretedASTRunner(ASTRunner):
     else: rawbufs[0] = ret
     return et
 
-from tinygrad.runtime.interpreted import interpret_ast
 class Interpreted:
   def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable]=None):
     self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
@@ -236,11 +225,50 @@ class Interpreted:
     self.method_cache: Dict[LazyOp, InterpretedASTRunner] = {}
 
   def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
-    if ast not in self.method_cache: self.method_cache[ast] = InterpretedASTRunner(ast, interpret_ast(self.fxn_for_op, self.from_underlying, ast))
+    if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, self.from_underlying, ast)
     rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
     self.method_cache[ast].exec(rawbufs, var_vals)
     output.realized = rawbufs[0]
 
+def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> InterpretedASTRunner:
+  if DEBUG >= 3:
+    from tinygrad.graph import print_tree
+    print_tree(ast)
+  tglob: Dict[str, Any] = {"Variable": Variable}
+  lines: List[str] = []
+
+  @functools.lru_cache(None)
+  def gstr(x:Any, nm=None) -> str:
+    if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
+      str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
+      # TODO: (Variable - Variable) might create NumNode. can we remove it?
+      return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
+    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
+    tglob[ret] = x
+    return ret
+
+  @functools.lru_cache(None)
+  def _interpret_ast(ast:LazyOp) -> str:
+    if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
+      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
+
+    if ast.op in BufferOps:
+      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
+      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
+    else:
+      inp = [_interpret_ast(src) for src in ast.src]
+      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
+
+    ret = f"a{len(lines)}"
+    lines.append(f" {ret} = {tmp}")
+    return ret
+
+  ret = _interpret_ast(ast)
+  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f" return {ret}"])
+  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
+  exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
+  return InterpretedASTRunner(ast, tglob['run'])
+
 # **************** for Compiled Buffers ****************
 
 class CompiledASTRunner(ASTRunner):
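
Note: get_interpreted_fxn flattens the LazyOp tree into one Python function; the gstr names (op names with '.' replaced by '_', or m0000-style slots holding arbitrary args) become globals bound to the backend's callables, and the whole thing is exec-compiled into tglob['run']. A rough illustration of the kind of source DEBUG >= 4 might print for a two-input elementwise add (the op names, the single RESHAPE movement op and its shape, and the buffer indices are invented for the example, not taken from a real trace):

def run(inputs, var_vals):
 a0 = MovementOps_RESHAPE(BufferOps_MEM(inputs[0]), (4, 4))
 a1 = MovementOps_RESHAPE(BufferOps_MEM(inputs[1]), (4, 4))
 a2 = BinaryOps_ADD(a0, a1)
 return from_underlying(a2)
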
@@ -272,8 +300,8 @@ class CompiledASTRunner(ASTRunner):
     lra = self.runtime_args.copy()
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
-    if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
-    self.update_stats(self.display_name if self.display_name is not None else self.name, var_vals, et, len(rawbufs), lra, jit)
+    et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2)
+    update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
     return et
 
 class Compiled:
diff --git a/tinygrad/runtime/interpreted.py b/tinygrad/runtime/interpreted.py
@@ -1,45 +0,0 @@
-from typing import Callable, Optional, Dict, List, Any
-import functools, re
-from tinygrad.helpers import DEBUG
-from tinygrad.ops import LazyOp, TernaryOps, ReduceOps, BinaryOps, BufferOps, Op
-from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.lib import RawBuffer
-
-def interpret_ast(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> Callable[[List[RawBuffer], Dict[Variable, int]], RawBuffer]:
-  if DEBUG >= 3:
-    from tinygrad.graph import print_tree
-    print_tree(ast)
-  tglob: Dict[str, Any] = {"Variable": Variable}
-  lines: List[str] = []
-
-  @functools.lru_cache(None)
-  def gstr(x:Any, nm=None) -> str:
-    if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
-      str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
-      # TODO: (Variable - Variable) might create NumNode. can we remove it?
-      return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
-    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
-    tglob[ret] = x
-    return ret
-
-  @functools.lru_cache(None)
-  def _interpret_ast(ast:LazyOp) -> str:
-    if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
-      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
-
-    if ast.op in BufferOps:
-      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
-      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
-    else:
-      inp = [_interpret_ast(src) for src in ast.src]
-      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
-
-    ret = f"a{len(lines)}"
-    lines.append(f" {ret} = {tmp}")
-    return ret
-
-  ret = _interpret_ast(ast)
-  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f" return {ret}"])
-  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
-  exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
-  return tglob['run']
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
@@ -4,7 +4,7 @@ import Metal, Cocoa, libdispatch
 from typing import List, Any, Tuple, Dict, Union, Set, cast
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup
-from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner
+from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner, update_stats
 from tinygrad.renderer.metal import MetalRenderer
 from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
 from tinygrad.shape.symbolic import Variable, Node
@@ -149,7 +149,7 @@ class MetalBatchExecutor(BatchExecutor):
     else:
       METAL.mtl_buffers_in_flight.append(command_buffer)
       et = None
-    super().update_stats(var_vals, et)
+    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=True, num_kernels=len(self.jit_cache))
    return et
 
 MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if METAL.supports_icb else BatchExecutor)
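
Note: the batched JIT path now reports a whole batch with a single call, using num_kernels so the kernel counter advances by the batch size (and the DEBUG >= 2 line is colored as a batch). A small sketch of the accounting effect with invented numbers, assuming the process-wide counters start at zero:

# hypothetical: one JIT batch of 5 kernels reported at once
from tinygrad.helpers import GlobalCounters
from tinygrad.ops import update_stats

update_stats("<batched 5>", op_estimate=5_000_000, mem_estimate=2_000_000, var_vals={},
             et=3.2e-4, buf_count=4, jit=True, num_kernels=5)
assert GlobalCounters.kernel_count == 5   # advanced by num_kernels, not by 1
assert GlobalCounters.global_ops == 5_000_000
assert abs(GlobalCounters.time_sum_s - 3.2e-4) < 1e-12
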