diff --git a/docs/env_vars.md b/docs/env_vars.md
index 3ca6c6063f..badb99cf37 100644
--- a/docs/env_vars.md
+++ b/docs/env_vars.md
@@ -45,10 +45,8 @@ OPT | [1-3] | optimization level
 BEAM | [#] | number of beams in kernel beam search
 GRAPH | [1] | create a graph of all operations (requires graphviz)
 GRAPHPATH | [/path/to] | where to put the generated graph
-PRINT_PRG | [1] | print program code
 IMAGE | [1] | enable 2d specific optimizations
 FLOAT16 | [1] | use float16 for images instead of float32
-ENABLE_METHOD_CACHE | [1] | enable method cache (this is the default)
 DISALLOW_ASSIGN | [1] | disallow assignment of tensors
 CL_EXCLUDE | [name0,name1] | comma-separated list of device names to exclude when using OpenCL GPU backend (like `CL_EXCLUDE=gfx1036`)
 CL_PLATFORM | [# >= 0] | index of the OpenCL [platform](https://documen.tician.de/pyopencl/runtime_platform.html#pyopencl.Platform) to run on. Defaults to 0.
diff --git a/test/test_custom_function.py b/test/test_custom_function.py
index 7d4cca0086..b9312a47b1 100644
--- a/test/test_custom_function.py
+++ b/test/test_custom_function.py
@@ -28,7 +28,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
   return ret.realized
 
 def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
-  return Device[ret.device].from_underlying(np.arctan2(a.realized._buf, b.realized._buf))
+  return Device[ret.device].buffer.fromCPU(np.arctan2(a.realized._buf, b.realized._buf))
 
 # *** second, we write the ATan2 mlop ***
 # NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 57f19e2009..672c6e1f20 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import importlib, inspect, functools, pathlib, re
+import importlib, inspect, functools, pathlib
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
@@ -128,7 +128,7 @@ class BatchExecutor:
 
   def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
     for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
-    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
+    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True)
     self.clear_jit_inputs()
 
   def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
@@ -146,66 +146,6 @@ class BatchExecutor:
   def clear_jit_inputs(self):
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
 
-# **************** for Interpreted Buffers ****************
-
-class Interpreted:
-  def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
-    self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
-    self.synchronize = lambda: None
-    self.batch_executor = BatchExecutor
-    self.codegen = None
-    self.method_cache: Dict[LazyOp, Callable] = {}
-
-  def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
-    if DEBUG >= 3:
-      from tinygrad.graph import print_tree
-      print_tree(ast)
-    tglob: Dict[str, Any] = {"Variable": Variable}
-    lines: List[str] = []
-    f = self.fxn_for_op
-
-    @functools.lru_cache(None)
-    def gstr(x:Any, nm=None) -> str:
-      if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
-        str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
-        # TODO: (Variable - Variable) might create NumNode. can we remove it?
-        return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
-      ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
-      tglob[ret] = x
-      return ret
-
-    @functools.lru_cache(None)
-    def _interpret_ast(ast:LazyOp) -> str:
-      if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
-        ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
-
-      if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
-        tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
-        for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
-      else:
-        inp = [_interpret_ast(src) for src in ast.src]
-        tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
-
-      ret = f"a{len(lines)}"
-      lines.append(f"  {ret} = {tmp}")
-      return ret
-
-    ret = _interpret_ast(ast)
-    src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {gstr(self.from_underlying, 'from_underlying')}({ret})"])
-    if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
-    exec(compile(src, "", "exec"), tglob) # pylint: disable=exec-used
-    return tglob['run']
-
-  def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
-    if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
-    ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
-    assert ret.dtype == output.dtype, f"{ret.dtype} != {output.dtype}"
-    if output.output_buffer is not None:
-      assert output.output_buffer.dtype == ret.dtype
-      output.output_buffer._buf = ret._buf
-      return output.output_buffer
-    return ret
-
 # **************** independent FlopCounter ****************
 
 @dataclass
@@ -234,6 +174,24 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
   def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
   return run_ast(ast)
 
+# **************** for Interpreted Buffers ****************
+
+class Interpreted:
+  def __init__(self, buffer: Type[RawBuffer], compiler: Callable[[LazyOp], Callable]):
+    self.buffer, self.compiler = buffer, compiler
+    self.synchronize = lambda: None
+    self.batch_executor = BatchExecutor
+    self.codegen = None
+    self.method_cache: Dict[LazyOp, Callable] = {}
+
+  def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
+    if ast not in self.method_cache: self.method_cache[ast] = self.compiler(ast)
+    output.realized = output.output_buffer # NOTE: assuming this is the right size and dtype from assign
+    ret: RawBuffer = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
+    assert output.dtype == ret.dtype, f"expected {output.dtype}, got {ret.dtype}"
+    if output.realized is not None: output.realized._buf = ret._buf
+    else: output.realized = ret
+
 # **************** for Compiled Buffers ****************
 
 class ASTRunner:
@@ -247,7 +205,7 @@ class ASTRunner:
     self.clprg = runtime(self.name, self.lib)
     return self
 
-  def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
+  def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
     from tinygrad.jit import CacheCollector
     CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
     return self(rawbufs, var_vals, force_wait=force_wait)
@@ -259,6 +217,7 @@ class ASTRunner:
 
   def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
     if var_vals is None: var_vals = {}
+    var_vals = {k:var_vals[k] for k in self.vars} # filter the var_vals
     global_size, local_size = self.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
       # TODO: this is copied from get_program
@@ -327,7 +286,8 @@ class Compiled:
   def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
     # check if we can reuse the output buffer
     # if it's aliased, don't use it
-    # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
+    # TODO: this is pretty wrong actually, who knows where else this buffer is used?
+    # TODO: what if an assign is required? this silently is wrong
     output.realized = output.output_buffer
     if output.realized is not None:
       for i,a in enumerate(inputs):
@@ -345,13 +305,5 @@ class Compiled:
 
     # all the rawbuffers
     rawbuffers = [output.realized] + [x.realized for x in inputs]
-    if getenv("ENABLE_METHOD_CACHE", 1):
-      if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
-      prg = self.method_cache[ast]
-    else:
-      prg = self.get_optimized_program(ast, rawbuffers)
-
-    if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
-
-    prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in prg.vars})
-    return output.realized
+    if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
+    self.method_cache[ast].exec(rawbuffers, var_vals)
diff --git a/tinygrad/realize.py b/tinygrad/realize.py
index 64e9c5c1de..b8b4a67c05 100644
--- a/tinygrad/realize.py
+++ b/tinygrad/realize.py
@@ -23,7 +23,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
       for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
       LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
     else:
-      si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
+      Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
     del si.out.op
     for v in si.out.views: del v.op
     assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
diff --git a/tinygrad/runtime/interpreted.py b/tinygrad/runtime/interpreted.py
new file mode 100644
index 0000000000..06fda8e734
--- /dev/null
+++ b/tinygrad/runtime/interpreted.py
@@ -0,0 +1,45 @@
+from typing import Callable, Optional, Dict, List, Any
+import functools, re
+from tinygrad.helpers import DEBUG
+from tinygrad.ops import LazyOp, TernaryOps, ReduceOps, BinaryOps, BufferOps, Op
+from tinygrad.shape.symbolic import Variable
+from tinygrad.runtime.lib import RawBuffer
+
+def interpret_ast(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> Callable[[List[RawBuffer], Dict[Variable, int]], RawBuffer]:
+  if DEBUG >= 3:
+    from tinygrad.graph import print_tree
+    print_tree(ast)
+  tglob: Dict[str, Any] = {"Variable": Variable}
+  lines: List[str] = []
+
+  @functools.lru_cache(None)
+  def gstr(x:Any, nm=None) -> str:
+    if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
+      str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
+      # TODO: (Variable - Variable) might create NumNode. can we remove it?
+      return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
+    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
+    tglob[ret] = x
+    return ret
+
+  @functools.lru_cache(None)
+  def _interpret_ast(ast:LazyOp) -> str:
+    if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
+      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
+
+    if ast.op in BufferOps:
+      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
+      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
+    else:
+      inp = [_interpret_ast(src) for src in ast.src]
+      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
+
+    ret = f"a{len(lines)}"
+    lines.append(f"  {ret} = {tmp}")
+    return ret
+
+  ret = _interpret_ast(ast)
+  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f"  return {ret}"])
+  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
+  exec(compile(src, "", "exec"), tglob) # pylint: disable=exec-used
+  return tglob['run']
\ No newline at end of file
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index e4267fde7f..43c75d5322 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -1,9 +1,10 @@
 import numpy as np
-import operator
+import operator, functools
 from typing import Callable, Dict, Tuple, Optional
 from tinygrad.helpers import dtypes, DType
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
 from tinygrad.runtime.lib import RawBuffer
+from tinygrad.runtime.interpreted import interpret_ast
 
 def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
   assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@@ -51,4 +52,4 @@ class RawNumpyBuffer(RawBuffer):
   @classmethod
   def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
   def toCPU(self): return self._buf
-CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)
+CPUBuffer = Interpreted(RawNumpyBuffer, functools.partial(interpret_ast, numpy_fxn_for_op, RawNumpyBuffer.fromCPU))
diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py
index 583fec05f5..afef209abb 100644
--- a/tinygrad/runtime/ops_disk.py
+++ b/tinygrad/runtime/ops_disk.py
@@ -1,8 +1,9 @@
-import os, mmap
+import os, mmap, functools
 from typing import Optional
 from typing import Callable, Dict, Tuple
 from tinygrad.helpers import prod, DType
 from tinygrad.runtime.lib import RawBufferMapped
+from tinygrad.runtime.interpreted import interpret_ast
 from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps
 
 class RawDiskBuffer(RawBufferMapped):
@@ -38,4 +39,4 @@ class RawDiskBuffer(RawBufferMapped):
     self._buf[0].readinto(buf)
 
 disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
-DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, from_underlying=lambda x:x)
\ No newline at end of file
+DiskBuffer = Interpreted(RawDiskBuffer, functools.partial(interpret_ast, disk_fxn_for_op, None))
diff --git a/tinygrad/runtime/ops_shm.py b/tinygrad/runtime/ops_shm.py
index 0ebdfe904d..a4274cd4d6 100644
--- a/tinygrad/runtime/ops_shm.py
+++ b/tinygrad/runtime/ops_shm.py
@@ -1,10 +1,11 @@
-import os, mmap
+import os, mmap, functools
 try: import _posixshmem
 except Exception: pass
 from typing import Callable, Dict
 from tinygrad.helpers import DType, OSX
 from tinygrad.runtime.lib import RawBufferMapped
 from tinygrad.ops import Interpreted, Op, UnaryOps, MovementOps, BufferOps
+from tinygrad.runtime.interpreted import interpret_ast
 
 class RawShmBuffer(RawBufferMapped):
   def __init__(self, size, dtype:DType, device:str):
@@ -25,4 +26,4 @@ class RawShmBuffer(RawBufferMapped):
 
 # TODO: is this wrong?
 shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x, MovementOps.AS_STRIDED: lambda x,_:x }
-ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op, from_underlying=lambda x:x)
+ShmBuffer = Interpreted(RawShmBuffer, functools.partial(interpret_ast, shm_fxn_for_op, None))
diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py
index fd159cd7e2..3f9ed86b9e 100644
--- a/tinygrad/runtime/ops_torch.py
+++ b/tinygrad/runtime/ops_torch.py
@@ -1,10 +1,11 @@
-import torch
+import torch, functools
 import numpy as np
 from typing import Dict, Callable, Optional
 from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, Op, Interpreted
 from tinygrad.helpers import getenv, dtypes, prod, DType
 from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
 from tinygrad.runtime.lib import RawBuffer
+from tinygrad.runtime.interpreted import interpret_ast
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
 type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
@@ -48,4 +49,4 @@ class RawTorchBuffer(RawBuffer):
     buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
     return cls(prod(x.shape), type_map[buf.dtype], buf)
   def toCPU(self): return self._buf.cpu().numpy()
-TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, from_underlying=lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))
+TorchBuffer = Interpreted(RawTorchBuffer, functools.partial(interpret_ast, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x)))
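
Reviewer note (not part of the patch): the new tinygrad/runtime/interpreted.py keeps the same code-generation trick the old Interpreted.interpret_ast used. It walks the LazyOp tree once, emits one "a{n} = ..." assignment per node into a flat run(inputs, var_vals) function, stashes the real Python callables in the exec() globals under generated "m0000"-style names, and the compiled function is then cached per AST in Interpreted.method_cache and called by exec_ast with the realized input buffers and var_vals. A minimal, self-contained toy analogue of that trick (hypothetical names, not tinygrad code) is sketched below:

# Toy analogue of interpret_ast's codegen (hypothetical example, not tinygrad code):
# emit one assignment per tree node into a flat run() function, keep the real
# callables in the exec() globals under generated names, compile once, reuse.
from typing import Any, Callable, Dict, List, Tuple

Node = Tuple[str, Any]  # ("LOAD", input_index) or (op_name, [child nodes])

def compile_tree(fxn_for_op: Dict[str, Callable], ast: Node) -> Callable[[List[Any]], Any]:
  tglob: Dict[str, Any] = {}   # globals for the generated function
  lines: List[str] = []
  def gstr(x: Any) -> str:     # stash a Python object under a generated name
    name = f"m{len(tglob):04d}"
    tglob[name] = x
    return name
  def walk(node: Node) -> str:
    op, arg = node
    if op == "LOAD": tmp = f"inputs[{arg}]"
    else: tmp = f"{gstr(fxn_for_op[op])}({', '.join(walk(c) for c in arg)})"
    ret = f"a{len(lines)}"
    lines.append(f"  {ret} = {tmp}")
    return ret
  out = walk(ast)
  src = "\n".join(["def run(inputs):"] + lines + [f"  return {out}"])
  exec(compile(src, "<toy>", "exec"), tglob)  # defines run() inside tglob
  return tglob["run"]

# usage: compile (x0 + x1) * x0 once, then call it like a kernel
fxns = {"ADD": lambda a, b: a + b, "MUL": lambda a, b: a * b}
run = compile_tree(fxns, ("MUL", [("ADD", [("LOAD", 0), ("LOAD", 1)]), ("LOAD", 0)]))
print(run([3, 4]))  # prints 21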