Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09.
Interpreted cleanups (#2312)
* move the compiler out of ops
* don't return realized
* var_vals filter, fix custom
* typing
@@ -45,10 +45,8 @@ OPT | [1-3] | optimization level
BEAM | [#] | number of beams in kernel beam search
GRAPH | [1] | create a graph of all operations (requires graphviz)
GRAPHPATH | [/path/to] | where to put the generated graph
PRINT_PRG | [1] | print program code
IMAGE | [1] | enable 2d specific optimizations
FLOAT16 | [1] | use float16 for images instead of float32
ENABLE_METHOD_CACHE | [1] | enable method cache (this is the default)
DISALLOW_ASSIGN | [1] | disallow assignment of tensors
CL_EXCLUDE | [name0,name1] | comma-separated list of device names to exclude when using OpenCL GPU backend (like `CL_EXCLUDE=gfx1036`)
CL_PLATFORM | [# >= 0] | index of the OpenCL [platform](https://documen.tician.de/pyopencl/runtime_platform.html#pyopencl.Platform) to run on. Defaults to 0.
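All of these flags are plain environment variables read through tinygrad's `getenv` helper. As a hedged, self-contained sketch of that pattern (not a copy of tinygrad's actual helper), reading an integer flag with a default looks like this:

```python
import os

def getenv(key: str, default=0):
  # read a flag from the environment, coerced to the type of the default
  return type(default)(os.getenv(key, default))

DEBUG = getenv("DEBUG")                                  # e.g. DEBUG=4 prints generated source
ENABLE_METHOD_CACHE = getenv("ENABLE_METHOD_CACHE", 1)   # method cache is on by default
```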
@@ -28,7 +28,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
  return ret.realized

def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
  return Device[ret.device].from_underlying(np.arctan2(a.realized._buf, b.realized._buf))
  return Device[ret.device].buffer.fromCPU(np.arctan2(a.realized._buf, b.realized._buf))

# *** second, we write the ATan2 mlop ***
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
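This is the "fix custom" part of the commit: with `from_underlying` gone, the custom CPU op wraps its numpy result via the backend's `buffer.fromCPU` instead. A self-contained toy of that shape (toy buffer class, not tinygrad's `RawBuffer`):

```python
import numpy as np

class ToyRawBuffer:
  # stand-in for a backend raw buffer that just wraps a numpy array
  def __init__(self, buf: np.ndarray): self._buf = buf
  @classmethod
  def fromCPU(cls, x: np.ndarray): return cls(x)
  def toCPU(self) -> np.ndarray: return self._buf

def atan2_cpu(a: ToyRawBuffer, b: ToyRawBuffer) -> ToyRawBuffer:
  # compute with numpy on the underlying arrays, then re-wrap the result
  return ToyRawBuffer.fromCPU(np.arctan2(a._buf, b._buf))

print(atan2_cpu(ToyRawBuffer(np.array([1.0])), ToyRawBuffer(np.array([2.0]))).toCPU())  # [0.46364761]
```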
tinygrad/ops.py (100 lines changed)
@@ -1,5 +1,5 @@
from __future__ import annotations
import importlib, inspect, functools, pathlib, re
import importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
@@ -128,7 +128,7 @@ class BatchExecutor:

  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True)
    self.clear_jit_inputs()

  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
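`BatchExecutor.__call__` replays a recorded list of programs: `input_replace` maps (cache index, argument index) to an input name so fresh buffers can be spliced into the cached argument lists on every call. The only change here is that the full `var_vals` is now passed through, since each program filters it for itself (see the `ASTRunner.__call__` hunk below). A toy sketch of the splice-and-replay pattern, with invented names:

```python
from typing import Dict, List, Tuple

class JitItem:
  # toy cached call: a function plus the argument buffers it was recorded with
  def __init__(self, prg, rawbufs: List): self.prg, self.rawbufs = prg, rawbufs

jit_cache = [JitItem(lambda bufs: print("sum:", sum(bufs[0]) + sum(bufs[1])), [None, [10, 20]])]
input_replace: Dict[Tuple[int, int], str] = {(0, 0): "x"}   # (cache idx, arg idx) -> input name

def replay(input_rawbuffers: Dict[str, List]):
  # splice the fresh inputs into the recorded argument lists, then run every cached program
  for (j, i), name in input_replace.items(): jit_cache[j].rawbufs[i] = input_rawbuffers[name]
  for ji in jit_cache: ji.prg(ji.rawbufs)

replay({"x": [1, 2, 3]})  # sum: 36
```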
@@ -146,66 +146,6 @@ class BatchExecutor:
  def clear_jit_inputs(self):
    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None

# **************** for Interpreted Buffers ****************

class Interpreted:
  def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
    self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
    self.synchronize = lambda: None
    self.batch_executor = BatchExecutor
    self.codegen = None
    self.method_cache: Dict[LazyOp, Callable] = {}

  def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
    if DEBUG >= 3:
      from tinygrad.graph import print_tree
      print_tree(ast)
    tglob: Dict[str, Any] = {"Variable": Variable}
    lines: List[str] = []
    f = self.fxn_for_op

    @functools.lru_cache(None)
    def gstr(x:Any, nm=None) -> str:
      if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
        str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
        # TODO: (Variable - Variable) might create NumNode. can we remove it?
        return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
      ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
      tglob[ret] = x
      return ret

    @functools.lru_cache(None)
    def _interpret_ast(ast:LazyOp) -> str:
      if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
        ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)

      if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
        tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
        for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
      else:
        inp = [_interpret_ast(src) for src in ast.src]
        tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"

      ret = f"a{len(lines)}"
      lines.append(f" {ret} = {tmp}")
      return ret

    ret = _interpret_ast(ast)
    src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})"])
    if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
    exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
    return tglob['run']

  def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
    if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
    ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
    assert ret.dtype == output.dtype, f"{ret.dtype} != {output.dtype}"
    if output.output_buffer is not None:
      assert output.output_buffer.dtype == ret.dtype
      output.output_buffer._buf = ret._buf
      return output.output_buffer
    return ret

# **************** independent FlopCounter ****************

@dataclass
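Everything from `# **************** for Interpreted Buffers ****************` down is what this commit deletes from ops.py: `interpret_ast` is a tiny source-to-source compiler that walks the LazyOp tree, emits one assignment per node into a generated `run(inputs, var_vals)` function, and `exec`s it so repeated executions skip the tree walk. It reappears in tinygrad/runtime/interpreted.py below. A self-contained sketch of that exec/compile pattern, using a toy op table and nested tuples instead of LazyOps:

```python
from typing import Any, Callable, Dict, List

def build_runner(fxn_for_op: Dict[str, Callable], ast) -> Callable:
  # ast is a nested tuple like ("ADD", ("INPUT", 0), ("INPUT", 1)), a toy stand-in for LazyOp
  tglob: Dict[str, Any] = {"fxn_for_op": fxn_for_op}
  lines: List[str] = []

  def emit(node) -> str:
    op, *src = node
    if op == "INPUT": return f"inputs[{src[0]}]"
    args = ", ".join(emit(s) for s in src)   # children are emitted (and named) first
    ret = f"a{len(lines)}"
    lines.append(f" {ret} = fxn_for_op[{op!r}]({args})")
    return ret

  out = emit(ast)
  src = "\n".join(["def run(inputs):"] + lines + [f" return {out}"])
  exec(compile(src, "<ast>", "exec"), tglob)  # build the function once, reuse it many times
  return tglob["run"]

run = build_runner({"ADD": lambda a, b: a + b, "MUL": lambda a, b: a * b},
                   ("ADD", ("MUL", ("INPUT", 0), ("INPUT", 1)), ("INPUT", 0)))
print(run([3, 4]))  # 3*4 + 3 = 15
```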
@@ -234,6 +174,24 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
  def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
  return run_ast(ast)

# **************** for Interpreted Buffers ****************

class Interpreted:
  def __init__(self, buffer: Type[RawBuffer], compiler: Callable[[LazyOp], Callable]):
    self.buffer, self.compiler = buffer, compiler
    self.synchronize = lambda: None
    self.batch_executor = BatchExecutor
    self.codegen = None
    self.method_cache: Dict[LazyOp, Callable] = {}

  def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
    if ast not in self.method_cache: self.method_cache[ast] = self.compiler(ast)
    output.realized = output.output_buffer # NOTE: assuming this is the right size and dtype from assign
    ret: RawBuffer = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
    assert output.dtype == ret.dtype, f"expected {output.dtype}, got {ret.dtype}"
    if output.realized is not None: output.realized._buf = ret._buf
    else: output.realized = ret

# **************** for Compiled Buffers ****************

class ASTRunner:
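The replacement `Interpreted` is much thinner: it no longer knows about op tables or `from_underlying`, it just takes any `compiler: Callable[[LazyOp], Callable]`, memoizes the compiled callable per AST in `method_cache`, and writes the result into `output.realized` instead of returning it. A toy of the compiler-plus-method-cache shape (invented names, `eval` standing in for real code generation):

```python
from typing import Callable, Dict

class ToyInterpreted:
  # compile each distinct AST once, then reuse the cached callable on later calls
  def __init__(self, compiler: Callable[[str], Callable]):
    self.compiler = compiler
    self.method_cache: Dict[str, Callable] = {}

  def exec_ast(self, ast: str, inputs):
    if ast not in self.method_cache: self.method_cache[ast] = self.compiler(ast)
    return self.method_cache[ast](inputs)

backend = ToyInterpreted(lambda ast: eval(f"lambda inputs: {ast}"))
print(backend.exec_ast("inputs[0] + inputs[1]", [3, 4]))  # 7, compiled on first use
print(backend.exec_ast("inputs[0] + inputs[1]", [5, 6]))  # 11, served from method_cache
```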
@@ -247,7 +205,7 @@ class ASTRunner:
    self.clprg = runtime(self.name, self.lib)
    return self

  def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
  def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
    from tinygrad.jit import CacheCollector
    CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
    return self(rawbufs, var_vals, force_wait=force_wait)
@@ -259,6 +217,7 @@ class ASTRunner:

  def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
    if var_vals is None: var_vals = {}
    var_vals = {k:var_vals[k] for k in self.vars} # filter the var_vals
    global_size, local_size = self.launch_dims(var_vals)
    if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
      # TODO: this is copied from get_program
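The added `var_vals = {k:var_vals[k] for k in self.vars}` is the "var_vals filter" from the commit message: every runner narrows the dictionary to the symbolic variables it actually declares, so callers (the JIT above, Compiled.exec_ast below) can pass the full set without per-call filtering. A minimal sketch of the idea with invented names:

```python
from typing import Dict, List

class ToyRunner:
  def __init__(self, name: str, vars: List[str]): self.name, self.vars = name, vars
  def __call__(self, var_vals: Dict[str, int]):
    var_vals = {k: var_vals[k] for k in self.vars}  # keep only the variables this program declares
    print(f"{self.name} launched with {var_vals}")

all_vals = {"batch": 8, "seqlen": 128, "unused": 3}
ToyRunner("kernel_0", ["batch"])(all_vals)            # kernel_0 launched with {'batch': 8}
ToyRunner("kernel_1", ["batch", "seqlen"])(all_vals)  # kernel_1 launched with {'batch': 8, 'seqlen': 128}
```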
@@ -327,7 +286,8 @@ class Compiled:
  def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
    # check if we can reuse the output buffer
    # if it's aliased, don't use it
    # NOTE: this is pretty wrong actually, who knows where else this buffer is used?
    # TODO: this is pretty wrong actually, who knows where else this buffer is used?
    # TODO: what if an assign is required? this silently is wrong
    output.realized = output.output_buffer
    if output.realized is not None:
      for i,a in enumerate(inputs):
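The comments above describe the output-buffer reuse heuristic: `output.output_buffer` is only written in place if it is not aliased by one of the inputs. Roughly, the guard amounts to the following check (a toy sketch, not the exact tinygrad logic):

```python
import numpy as np

def pick_output_buffer(candidate, inputs):
  # reuse an existing buffer only if no input aliases it; otherwise force a fresh allocation
  if candidate is not None and any(candidate is x for x in inputs): return None
  return candidate

out = np.zeros(4)
print(pick_output_buffer(out, [out, np.ones(4)]))    # None: aliased by an input, don't reuse
print(pick_output_buffer(out, [np.ones(4)]) is out)  # True: safe to write in place
```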
@@ -345,13 +305,5 @@ class Compiled:
    # all the rawbuffers
    rawbuffers = [output.realized] + [x.realized for x in inputs]

    if getenv("ENABLE_METHOD_CACHE", 1):
      if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
      prg = self.method_cache[ast]
    else:
      prg = self.get_optimized_program(ast, rawbuffers)

    if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)

    prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in prg.vars})
    return output.realized
    if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
    self.method_cache[ast].exec(rawbuffers, var_vals)
@@ -23,7 +23,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
      for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
      LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
    else:
      si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
      Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
    del si.out.op
    for v in si.out.views: del v.op
    assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
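`run_schedule` now relies on `exec_ast` filling in `si.out.realized` itself (the "don't return realized" part of the commit), while LoadOps still go through `LOAD_OPS_DISPATCHER`, a plain op-to-function dictionary. A minimal sketch of that dict-dispatch pattern with invented ops:

```python
from typing import Callable, Dict

def load_empty(out): out["buf"] = []                     # toy stand-ins for the real load ops
def load_from(out, src): out["buf"] = list(src["buf"])

LOAD_OPS_DISPATCHER: Dict[str, Callable] = {"EMPTY": load_empty, "FROM": load_from}

out: Dict = {}
LOAD_OPS_DISPATCHER["EMPTY"](out)   # look the op up, then call it with the output and inputs
print(out)                          # {'buf': []}
```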
tinygrad/runtime/interpreted.py (new file, 45 lines)
@@ -0,0 +1,45 @@
from typing import Callable, Optional, Dict, List, Any
import functools, re
from tinygrad.helpers import DEBUG
from tinygrad.ops import LazyOp, TernaryOps, ReduceOps, BinaryOps, BufferOps, Op
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.lib import RawBuffer

def interpret_ast(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> Callable[[List[RawBuffer], Dict[Variable, int]], RawBuffer]:
  if DEBUG >= 3:
    from tinygrad.graph import print_tree
    print_tree(ast)
  tglob: Dict[str, Any] = {"Variable": Variable}
  lines: List[str] = []

  @functools.lru_cache(None)
  def gstr(x:Any, nm=None) -> str:
    if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
      str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
      # TODO: (Variable - Variable) might create NumNode. can we remove it?
      return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
    ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
    tglob[ret] = x
    return ret

  @functools.lru_cache(None)
  def _interpret_ast(ast:LazyOp) -> str:
    if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
      ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)

    if ast.op in BufferOps:
      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
      for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
    else:
      inp = [_interpret_ast(src) for src in ast.src]
      tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"

    ret = f"a{len(lines)}"
    lines.append(f" {ret} = {tmp}")
    return ret

  ret = _interpret_ast(ast)
  src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f" return {ret}"])
  if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
  exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
  return tglob['run']
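Note that both inner helpers are wrapped in `functools.lru_cache(None)`: a repeated subtree (or a repeated constant or function passed to `gstr`) is emitted once, and later references reuse the same generated name, which acts as a cheap common-subexpression elimination. A toy demonstration of that memoized emission:

```python
import functools
from typing import List

lines: List[str] = []

@functools.lru_cache(None)
def emit(expr: str) -> str:
  # each distinct expression gets one generated assignment; duplicates reuse the same name
  name = f"a{len(lines)}"
  lines.append(f"{name} = {expr}")
  return name

x = emit("load(0)")
y = emit("load(0)")                   # cache hit: nothing new is emitted
print(emit(f"add({x}, {y})"), lines)  # a1 ['a0 = load(0)', 'a1 = add(a0, a0)']
```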
@@ -1,9 +1,10 @@
import numpy as np
import operator
import operator, functools
from typing import Callable, Dict, Tuple, Optional
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.interpreted import interpret_ast

def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
  assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@@ -51,4 +52,4 @@ class RawNumpyBuffer(RawBuffer):
  @classmethod
  def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
  def toCPU(self): return self._buf
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)
CPUBuffer = Interpreted(RawNumpyBuffer, functools.partial(interpret_ast, numpy_fxn_for_op, RawNumpyBuffer.fromCPU))
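This is the new wiring for every interpreted backend: instead of handing `Interpreted` an op table plus a `from_underlying` hook, the backend pre-binds both into a single compiler with `functools.partial(interpret_ast, numpy_fxn_for_op, RawNumpyBuffer.fromCPU)`, leaving only the AST argument open. A self-contained reminder of what `partial` does here (toy function, mirroring only the argument order):

```python
import functools
from typing import Callable, Dict

def interpret(fxn_for_op: Dict[str, Callable], from_underlying: Callable, ast):
  # toy mirror of interpret_ast(fxn_for_op, from_underlying, ast)
  return from_underlying(fxn_for_op[ast[0]](*ast[1:]))

compiler = functools.partial(interpret, {"ADD": lambda a, b: a + b}, float)  # first two args bound
print(compiler(("ADD", 1, 2)))  # 3.0, only the AST is left for Interpreted to supply
```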
@@ -1,8 +1,9 @@
import os, mmap
import os, mmap, functools
from typing import Optional
from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.runtime.interpreted import interpret_ast
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps

class RawDiskBuffer(RawBufferMapped):
@@ -38,4 +39,4 @@ class RawDiskBuffer(RawBufferMapped):
    self._buf[0].readinto(buf)

disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, from_underlying=lambda x:x)
DiskBuffer = Interpreted(RawDiskBuffer, functools.partial(interpret_ast, disk_fxn_for_op, None))
@@ -1,10 +1,11 @@
import os, mmap
import os, mmap, functools
try: import _posixshmem
except Exception: pass
from typing import Callable, Dict
from tinygrad.helpers import DType, OSX
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.ops import Interpreted, Op, UnaryOps, MovementOps, BufferOps
from tinygrad.runtime.interpreted import interpret_ast

class RawShmBuffer(RawBufferMapped):
  def __init__(self, size, dtype:DType, device:str):
@@ -25,4 +26,4 @@ class RawShmBuffer(RawBufferMapped):

# TODO: is this wrong?
shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x, MovementOps.AS_STRIDED: lambda x,_:x }
ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op, from_underlying=lambda x:x)
ShmBuffer = Interpreted(RawShmBuffer, functools.partial(interpret_ast, shm_fxn_for_op, None))
@@ -1,10 +1,11 @@
import torch
import torch, functools
import numpy as np
from typing import Dict, Callable, Optional
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, Op, Interpreted
from tinygrad.helpers import getenv, dtypes, prod, DType
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.interpreted import interpret_ast

device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
@@ -48,4 +49,4 @@ class RawTorchBuffer(RawBuffer):
    buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
    return cls(prod(x.shape), type_map[buf.dtype], buf)
  def toCPU(self): return self._buf.cpu().numpy()
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, from_underlying=lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))
TorchBuffer = Interpreted(RawTorchBuffer, functools.partial(interpret_ast, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x)))
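None of this changes how the interpreted backends are selected from user code. Assuming this checkout of tinygrad with numpy installed, the CPU (numpy) backend is still picked by device name:

```python
from tinygrad.tensor import Tensor

# CPU is the numpy-backed Interpreted device; TORCH works the same way if torch is installed
a = Tensor([[1.0, 2.0], [3.0, 4.0]], device="CPU")
b = Tensor([[5.0, 6.0], [7.0, 8.0]], device="CPU")
print(a.matmul(b).numpy())  # [[19. 22.] [43. 50.]]
```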