diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py
index 80fed448de..87ebd24f4f 100644
--- a/test/external/fuzz_linearizer.py
+++ b/test/external/fuzz_linearizer.py
@@ -59,7 +59,7 @@ def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
     return "COMPILE_ERROR"
 
   try:
-    prg(rawbufs, var_vals, wait=True, do_update_stats=False)
+    prg(rawbufs, var_vals, wait=True)
   except Exception:
     traceback.print_exc()
     return "EXEC_ERROR"
diff --git a/tinygrad/device.py b/tinygrad/device.py
index d27e5c60f0..e784dd5ff5 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -1,9 +1,8 @@
 from __future__ import annotations
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, NamedTuple
+from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, NamedTuple, cast
 import importlib, inspect, functools, pathlib, time, ctypes, os
-from tinygrad.helpers import ansilen, prod, getenv, colored, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
-from tinygrad.helpers import DEBUG, BEAM, NOOPT, GlobalCounters
+from tinygrad.helpers import prod, getenv, colored, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put, DEBUG, BEAM, NOOPT
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.ops import LazyOp, get_lazyop_info
 from tinygrad.buffer import Buffer, BufferOptions
@@ -40,30 +39,20 @@ Device = _Device()
 # **************** base Runner + helpers ****************
 
 class Runner:
-  def __init__(self):
-    self.op_estimate:sint = 0
-    self.mem_estimate:sint = 0
+  def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0):
+    self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate = True, display_name, dname, op_estimate, mem_estimate
   def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
     return self(rawbufs, {} if var_vals is None else var_vals)
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
     raise NotImplementedError("override this")
 
-def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False):  # noqa: E501
-  if var_vals is None: var_vals = {}
-  op_estimate = sym_infer(op_estimate, var_vals)
-  mem_estimate = sym_infer(mem_estimate, var_vals)
-  GlobalCounters.kernel_count += num_kernels
-  GlobalCounters.global_ops += op_estimate
-  GlobalCounters.global_mem += mem_estimate
-  if et is not None: GlobalCounters.time_sum_s += et
-  if DEBUG >= 2:
-    ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
-    print(f"{colored(f'*** {device[:7]:7s} {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(38-ansilen(name))} arg {buf_count:3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
-          (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))  # noqa: E501
-
 # **************** Buffer / Allocator ****************
 
 class BufferCopy(Runner):
+  def __init__(self, total_sz, dest_device, src_device):
+    if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
+    else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
+    super().__init__(colored(name, "yellow"), dest_device, 0, total_sz)
   def copy(self, dest, src):
     if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_fd') and src.nbytes >= 4096 and hasattr(src.allocator.device, 'fd'):
       dest.allocator.copy_from_fd(dest._buf, src.allocator.device.fd, src._buf.offset, src.nbytes)
@@ -72,19 +61,14 @@ class BufferCopy(Runner):
       src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
     else:
       dest.copyin(src.as_buffer(allow_zero_copy=True))  # may allocate a CPU buffer depending on allow_zero_copy
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
     dest, src = rawbufs[0:2]
     assert dest.size == src.size and dest.dtype == src.dtype, f"buffer copy mismatch, {dest.size} != {src.size}, {dest.dtype} != {src.dtype}"
     st = time.perf_counter()
     self.copy(dest, src)
-    et = None
-    if wait or DEBUG >= 2:
+    if wait:
       Device[dest.device].synchronize()
-      et = time.perf_counter() - st
-    total_sz = dest.size*dest.dtype.itemsize
-    if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest.device[:7]:>7s} <- {src.device[:7]:7s}"
-    else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest.device[:7]:>7s} <- {src.device[:7]:7s}"
-    update_stats(colored(name, "yellow"), 0, total_sz, {}, et, 2, jit, device=dest.device)
+      return time.perf_counter() - st
 
 class BufferXfer(BufferCopy):
   def copy(self, dest, src):
@@ -157,17 +141,15 @@ class Compiler:
 class CompiledRunner(Runner):
   def __init__(self, name:str, prg:str, dname:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
                variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None, outcount:int=1):
-    super().__init__()
     if DEBUG >= 4: print(prg)
     if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
     if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
-    self.name, self.display_name, self.prg, self.dname, self.global_size, self.local_size, self.first_run = \
-      to_function_name(name), name, prg, dname, global_size, local_size, True
-    assert self.device.compiler is not None, "compiler is required to make an AST kernel"
-    lib:bytes = precompiled if precompiled is not None else self.device.compiler.compile_cached(prg)
-    self.lib, self.clprg, self.outcount = lib, self.device.runtime(self.name, lib), outcount
+    self.name, self.prg, self.global_size, self.local_size, self.first_run = \
+      to_function_name(name), prg, global_size, local_size, True
+    lib:bytes = precompiled if precompiled is not None else cast(Compiler, Device[dname].compiler).compile_cached(prg)
+    self.lib, self.clprg, self.outcount = lib, Device[dname].runtime(self.name, lib), outcount
     self.vars: List[Variable] = [] if variables is None else variables
-    self.op_estimate, self.mem_estimate = op_estimate, mem_estimate
+    super().__init__(name, dname, op_estimate, mem_estimate)
 
   def to_other_device(self, dname:str):
     return CompiledRunner(self.display_name, self.prg, dname, self.global_size, self.local_size,
@@ -185,7 +167,7 @@ class CompiledRunner(Runner):
     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
     return global_size, local_size
 
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False, do_update_stats=True) -> Optional[float]:
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
     global_size, local_size = self.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.global_size):  # type: ignore[arg-type]
       # TODO: this is copied from get_program
@@ -195,14 +177,10 @@ class CompiledRunner(Runner):
     lra = {}
     if global_size: lra['global_size'] = global_size
     if local_size: lra['local_size'] = local_size
-    et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
-    if do_update_stats: update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit,
-                                     lra=lra, device=self.dname, first_run=self.first_run)
-    self.first_run = False
-    return et
+    return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait)
 
 class MultiDeviceJITGraph(Runner):
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
     raise NotImplementedError("override this")
 
 method_cache: Dict[Tuple[str, Tuple[LazyOp, ...], bool], CompiledRunner] = {}
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index c4eae0855d..b50301ee3a 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -109,7 +109,7 @@ class TinyJit(Generic[ReturnType]):
       # jit exec
       assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
       for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
-      if DEBUG >= 1: print(f"jit execs {len(self.jit_cache)} kernels")
+      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
       for ei in self.jit_cache: ei.run(var_vals, jit=True)
     elif self.cnt == 1:
       # jit capture
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 8f3820638f..a42b39d8a2 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -2,28 +2,39 @@ from typing import List, Dict, Optional, cast, Generator, DefaultDict, Tuple, It
 from collections import defaultdict
 from dataclasses import dataclass
 from tinygrad.dtype import DType
-from tinygrad.helpers import colored, getenv, dedup, DEBUG
+from tinygrad.helpers import colored, getenv, dedup, DEBUG, GlobalCounters, ansilen
 from tinygrad.ops import ScheduleItem, BufferOps, LoadOps, copy_ast
-from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
+from tinygrad.device import Runner, Device, BufferCopy, BufferXfer
 from tinygrad.buffer import Buffer
-from tinygrad.shape.symbolic import Variable
+from tinygrad.shape.symbolic import Variable, sym_infer
 
 @dataclass(frozen=True)
 class ExecItem:
   prg: Runner
   rawbufs: List[Optional[Buffer]]
-  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False):
-    self.prg([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {}, wait=wait, jit=jit)
+  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
+    et = self.prg([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
+    if do_update_stats:
+      GlobalCounters.kernel_count += 1
+      GlobalCounters.global_ops += (op_estimate:=sym_infer(self.prg.op_estimate, var_vals))
+      GlobalCounters.global_mem += (mem_estimate:=sym_infer(self.prg.mem_estimate, var_vals))
+      if et is not None: GlobalCounters.time_sum_s += et
+      if DEBUG >= 2:
+        ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
+        print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(38-ansilen(self.prg.display_name))} arg {len(self.rawbufs):3d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
+              (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))  # noqa: E501
+      self.prg.first_run = False
+    return et
 
 class CustomOp(Runner):
   def __init__(self, fxn):
     self.fxn = fxn
-    super().__init__()
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
+    super().__init__(self.fxn.__name__, "CUSTOM", 0, 0)
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): self.fxn(*rawbufs)
 
 class EmptyOp(Runner):
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
-    update_stats(colored(f"empty {rawbufs[0].size:10d} {rawbufs[0].dtype}", "yellow"), 0, 0, {}, jit, 1, device=rawbufs[0].device)
+  def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
 
 def lower_schedule_item(si:ScheduleItem) -> Runner:
   assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
@@ -31,11 +42,13 @@ def lower_schedule_item(si:ScheduleItem) -> Runner:
   assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
   out, ast = si.outputs[0], si.ast[0]
   if ast.op is LoadOps.COPY:
+    kernel_type = BufferCopy
     if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
-      return Device[si.outputs[0].device].get_runner(copy_ast(ast.arg)) if getenv("USE_COPY_KERNEL") else BufferXfer()
-    return BufferCopy()
+      if getenv("USE_COPY_KERNEL"): return Device[out.device].get_runner(copy_ast(ast.arg))
+      kernel_type = BufferXfer
+    return kernel_type(ast.arg, out.device, si.inputs[0].device)
   if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
-  if ast.op is LoadOps.EMPTY: return EmptyOp()
+  if ast.op is LoadOps.EMPTY: return EmptyOp(out)
   raise RuntimeError(f"don't know how to lower {ast}")
 
 def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 9c6dcabe7e..52293a1e2c 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -262,7 +262,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
   # confirm everything was scheduled correctly
   if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
     raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
-  if DEBUG >= 1 and len(schedule) > 0: print(f"scheduled {len(schedule)} kernels")
+  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
   return schedule, var_vals
 
 def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py
index d8ad27e292..91cdf09f69 100644
--- a/tinygrad/features/search.py
+++ b/tinygrad/features/search.py
@@ -41,7 +41,7 @@ def _time_program(variables:List[Variable], outcount:int, rdev:Compiled, lib:byt
   for _ in range(cnt):
     if clear_l2:
       with Context(DEBUG=0, BEAM=0, CACHECOLLECTING=0): Tensor.ones(1024,1024).contiguous().realize()
-    tms.append(cast(float, car(rawbufs, var_vals, wait=True, do_update_stats=False))*factor)
+    tms.append(cast(float, car(rawbufs, var_vals, wait=True))*factor)
     if early_stop is not None and early_stop < tms[-1]: break
   return tms
 
diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py
index 6eee53d1ce..60e9d2f117 100644
--- a/tinygrad/runtime/graph/cuda.py
+++ b/tinygrad/runtime/graph/cuda.py
@@ -1,8 +1,8 @@
 import ctypes, collections
 from typing import Any, Optional, Tuple, Dict, List, cast
 import tinygrad.runtime.autogen.cuda as cuda
-from tinygrad.helpers import init_c_var, GraphException, getenv
-from tinygrad.device import CompiledRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
+from tinygrad.helpers import init_c_var, GraphException, getenv, colored
+from tinygrad.device import CompiledRunner, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
 from tinygrad.shape.symbolic import Variable
 from tinygrad.engine.realize import ExecItem
@@ -15,7 +15,6 @@ class CUDAGraph(MultiDeviceJITGraph):
 
     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
     self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
     self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
     self.jc_idxs_with_updatable_rawbufs = list(set([x[0] for x in self.input_replace.keys()]))
@@ -70,8 +69,9 @@ class CUDAGraph(MultiDeviceJITGraph):
 
     # clear jit inputs to allow their memory to be freed/reused
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "CUDA", *get_jit_stats(jit_cache))
 
-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     # Update rawbuffers in the c_args struct.
     for (j,i),input_idx in self.input_replace.items():
       if not self.updatable_nodes[j][3]: setattr(self.updatable_nodes[j][2], f'f{i}', input_rawbuffers[input_idx]._buf)
@@ -93,10 +93,7 @@ class CUDAGraph(MultiDeviceJITGraph):
       if not is_copy: check(cuda.cuGraphExecKernelNodeSetParams(self.instance, node, ctypes.byref(c_node_params)))
      else: check(cuda.cuGraphExecMemcpyNodeSetParams(self.instance, node, ctypes.byref(c_node_params), c_args))
 
-    et = cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)
-    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers),
-                 jit=jit, num_kernels=len(self.jit_cache), device="CUDA")
-    return et
+    return cu_time_execution(lambda: check(cuda.cuGraphLaunch(self.instance, None)), enable=wait)
 
   def __del__(self):
     if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py
index 7cefd82c49..8d5de5442f 100644
--- a/tinygrad/runtime/graph/hsa.py
+++ b/tinygrad/runtime/graph/hsa.py
@@ -1,8 +1,8 @@
 import ctypes, collections, time, itertools
 from typing import List, Any, Dict, cast, Optional, Union, Tuple
-from tinygrad.helpers import GraphException, init_c_var, round_up
+from tinygrad.helpers import GraphException, init_c_var, round_up, colored
 from tinygrad.buffer import Buffer, BufferOptions
-from tinygrad.device import Compiled, CompiledRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
+from tinygrad.device import Compiled, CompiledRunner, BufferXfer, MultiDeviceJITGraph, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
 from tinygrad.engine.realize import ExecItem
@@ -29,7 +29,6 @@ class HSAGraph(MultiDeviceJITGraph):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore
     self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
     self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
 
@@ -114,8 +113,9 @@ class HSAGraph(MultiDeviceJITGraph):
 
     # clear jit inputs to allow their memory to be freed/reused
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), "HSA", *get_jit_stats(jit_cache))
 
-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     # Wait and restore signals
     hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
     for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
@@ -158,8 +158,6 @@ class HSAGraph(MultiDeviceJITGraph):
       et = time.perf_counter() - st
 
     for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
-    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers),
-                 jit=jit, num_kernels=len(self.jit_cache), device="HSA")
     return et
 
   def alloc_signal(self, reset_on_start=False, wait_on=None):
diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py
index 192fa12abf..eeba44c632 100644
--- a/tinygrad/runtime/graph/metal.py
+++ b/tinygrad/runtime/graph/metal.py
@@ -1,20 +1,19 @@
 from typing import List, Any, Dict, cast, Optional
 import Metal
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import dedup, unwrap2, GraphException
-from tinygrad.device import Buffer, CompiledRunner, update_stats
+from tinygrad.helpers import dedup, unwrap2, GraphException, colored
+from tinygrad.device import Buffer, CompiledRunner, Runner
 from tinygrad.engine.realize import ExecItem
 from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_metal import MetalDevice, wait_check
 
-class MetalGraph:
+class MetalGraph(Runner):
   def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
 
     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
-    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
     self.jc_idx_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
 
     self.device: MetalDevice = device
@@ -54,8 +53,9 @@ class MetalGraph:
 
    # clear jit inputs to allow their memory to be freed/reused
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+    super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), device.dname, *get_jit_stats(jit_cache))
 
-  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
+  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
     # NOTE: you at least can't update the ints if this is running
     if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
     all_resources = self.all_resources + [x._buf for x in input_rawbuffers]
@@ -74,9 +74,6 @@ class MetalGraph:
     self.command_buffer = command_buffer
     if wait:
       wait_check(command_buffer)
-      et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
-    else:
-      self.device.mtl_buffers_in_flight.append(command_buffer)
-      et = None
-    update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))  # noqa: E501
-    return et
\ No newline at end of file
+      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
+    self.device.mtl_buffers_in_flight.append(command_buffer)
+    return None
\ No newline at end of file
diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py
index ef8a7ad887..832ccfcf10 100644
--- a/tinygrad/runtime/ops_disk.py
+++ b/tinygrad/runtime/ops_disk.py
@@ -52,7 +52,8 @@ class DiskRunner(Runner):
     assert strides_for_shape(view.shape) == view.strides, "disk tensors don't support strides"
     self.new_size = prod(view.shape)
     self.new_offset = view.offset * top_src.arg.dtype.itemsize
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False, jit=False):
+    super().__init__(f"sz 0x{self.new_size:X} offset 0x{self.new_offset:X}", "DISK")
+  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Any, int], wait=False):
     assert len(rawbufs) == 2
     # TODO: this is a terrible hack that should be moved to lazy.py
     rawbufs[0]._buf.offset = rawbufs[1]._buf.offset+self.new_offset
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 4973fa8cbd..6e780868f6 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -300,9 +300,9 @@ class AndNode(RedNode):
     return Node.ands(subed)
 
 def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
-def sym_infer(a: Union[Node, int], var_vals: Dict[Variable, int]) -> int:
+def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
   if isinstance(a, (int, float)): return a
-  ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()})
+  ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
   assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
   return ret.b
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c826a1ea36..13947e94a5 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1074,7 +1074,6 @@ if IMAGE:
 # TODO: eventually remove this
 def custom_random(out:Buffer):
   Tensor._seed += 1
-  if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}")
   rng = np.random.default_rng(Tensor._seed)
   if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
   else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False)
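
A minimal sketch of how the pieces fit together after this change, for illustration only; it is not part of the diff. The NoopRunner class and the "CLANG" device string are made-up assumptions. Per the diff, a Runner now describes itself once via super().__init__(display_name, dname, op_estimate, mem_estimate) and its __call__ only returns an elapsed time (or None), while ExecItem.run owns the GlobalCounters bookkeeping and the DEBUG>=2 per-kernel line that update_stats used to print.

# sketch, assuming a tinygrad tree with this patch applied
import time
from typing import Dict, List, Optional

from tinygrad.buffer import Buffer
from tinygrad.device import Runner
from tinygrad.engine.realize import ExecItem
from tinygrad.shape.symbolic import Variable

class NoopRunner(Runner):  # hypothetical runner, not part of tinygrad
  def __init__(self, dname:str):
    # register display name, device name, and op/mem estimates with the base class
    super().__init__("noop", dname, op_estimate=0, mem_estimate=0)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False) -> Optional[float]:
    st = time.perf_counter()
    # a real runner launches work here; it only returns a time, it no longer calls update_stats
    return time.perf_counter() - st if wait else None

# ExecItem.run times the call, bumps GlobalCounters, and (at DEBUG>=2) prints the stats line
et = ExecItem(NoopRunner("CLANG"), []).run(wait=True)

Because accounting now lives in one place, callers that want raw timings (fuzz_linearizer, features/search.py) simply drop the old do_update_stats=False flag and call the runner directly.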