Interpreted cleanups (#2312)

* move the compiler out of ops

* don't return realized

* var_vals filter, fix custom

* typing
George Hotz
2023-11-15 09:02:23 -08:00
committed by GitHub
parent 123a0b86b2
commit 4da2ddea6e
9 changed files with 85 additions and 86 deletions

@@ -45,10 +45,8 @@ OPT | [1-3] | optimization level
BEAM | [#] | number of beams in kernel beam search
GRAPH | [1] | create a graph of all operations (requires graphviz)
GRAPHPATH | [/path/to] | where to put the generated graph
PRINT_PRG | [1] | print program code
IMAGE | [1] | enable 2d specific optimizations
FLOAT16 | [1] | use float16 for images instead of float32
ENABLE_METHOD_CACHE | [1] | enable method cache (this is the default)
DISALLOW_ASSIGN | [1] | disallow assignment of tensors
CL_EXCLUDE | [name0,name1] | comma-separated list of device names to exclude when using OpenCL GPU backend (like `CL_EXCLUDE=gfx1036`)
CL_PLATFORM | [# >= 0] | index of the OpenCL [platform](https://documen.tician.de/pyopencl/runtime_platform.html#pyopencl.Platform) to run on. Defaults to 0.
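Note: the two entries dropped from this table track the ops.py change further down, where the getenv checks for PRINT_PRG and ENABLE_METHOD_CACHE are removed and the method cache becomes unconditional. The surviving flags are ordinary environment variables read through tinygrad.helpers.getenv, which casts the value to the type of its default; a minimal sketch of that pattern (GRAPH is picked only as an example flag from the table):

```python
# minimal sketch: getenv(name, default) reads the env var and casts it to the
# type of the default (int here), so GRAPH=1 enables the check below
from tinygrad.helpers import getenv

if getenv("GRAPH", 0): print("operation graphing enabled")
```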

@@ -28,7 +28,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
return ret.realized
def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
return Device[ret.device].from_underlying(np.arctan2(a.realized._buf, b.realized._buf))
return Device[ret.device].buffer.fromCPU(np.arctan2(a.realized._buf, b.realized._buf))
# *** second, we write the ATan2 mlop ***
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
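Since Interpreted devices no longer carry a from_underlying hook, a custom CPU op is now expected to wrap its numpy result into a raw buffer itself, via the device's buffer class exactly as the updated atan2_cpu above does. A minimal sketch of that convention (double_cpu is a hypothetical op; Device is assumed to be imported as in the original example file):

```python
import numpy as np

def double_cpu(ret, a):
    # wrap the numpy result directly with the device's RawBuffer class,
    # mirroring the updated atan2_cpu above (hypothetical op, for illustration)
    return Device[ret.device].buffer.fromCPU(np.multiply(a.realized._buf, 2))
```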

@@ -1,5 +1,5 @@
from __future__ import annotations
import importlib, inspect, functools, pathlib, re
import importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
@@ -128,7 +128,7 @@ class BatchExecutor:
def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), var_vals, jit=True)
self.clear_jit_inputs()
def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
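The per-call filtering `{v: var_vals[v] for v in getattr(ji.prg, "vars", [])}` disappears here because ASTRunner.__call__ now filters var_vals itself (see the hunk below), so the JIT loop can pass the full dict through. A small standalone sketch of the idiom:

```python
from tinygrad.shape.symbolic import Variable

# each runner keeps only the Variables it actually uses (its `vars`),
# mirroring the `{k: var_vals[k] for k in self.vars}` line added below
i, j = Variable("i", 1, 10), Variable("j", 1, 10)
var_vals = {i: 4, j: 7}          # full dict, as the JIT now passes it
runner_vars = [i]                # stand-in for an ASTRunner's self.vars
filtered = {k: var_vals[k] for k in runner_vars}
assert filtered == {i: 4}
```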
@@ -146,66 +146,6 @@ class BatchExecutor:
def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
# **************** for Interpreted Buffers ****************
class Interpreted:
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
self.synchronize = lambda: None
self.batch_executor = BatchExecutor
self.codegen = None
self.method_cache: Dict[LazyOp, Callable] = {}
def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
tglob: Dict[str, Any] = {"Variable": Variable}
lines: List[str] = []
f = self.fxn_for_op
@functools.lru_cache(None)
def gstr(x:Any, nm=None) -> str:
if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
# TODO: (Variable - Variable) might create NumNode. can we remove it?
return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
tglob[ret] = x
return ret
@functools.lru_cache(None)
def _interpret_ast(ast:LazyOp) -> str:
if TernaryOps.MULACC in f and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
if MovementOps.AS_STRIDED in f and ast.op in BufferOps:
tmp = f"{gstr(f[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(f[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(f[mop], mop)}({tmp}, {gstr(arg)})"
else:
inp = [_interpret_ast(src) for src in ast.src]
tmp = f"{gstr(f[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
ret = f"a{len(lines)}"
lines.append(f" {ret} = {tmp}")
return ret
ret = _interpret_ast(ast)
src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})"])
if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
return tglob['run']
def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
assert ret.dtype == output.dtype, f"{ret.dtype} != {output.dtype}"
if output.output_buffer is not None:
assert output.output_buffer.dtype == ret.dtype
output.output_buffer._buf = ret._buf
return output.output_buffer
return ret
# **************** independent FlopCounter ****************
@dataclass
@@ -234,6 +174,24 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter:
def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else [])))
return run_ast(ast)
# **************** for Interpreted Buffers ****************
class Interpreted:
def __init__(self, buffer: Type[RawBuffer], compiler: Callable[[LazyOp], Callable]):
self.buffer, self.compiler = buffer, compiler
self.synchronize = lambda: None
self.batch_executor = BatchExecutor
self.codegen = None
self.method_cache: Dict[LazyOp, Callable] = {}
def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
if ast not in self.method_cache: self.method_cache[ast] = self.compiler(ast)
output.realized = output.output_buffer # NOTE: assuming this is the right size and dtype from assign
ret: RawBuffer = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
assert output.dtype == ret.dtype, f"expected {output.dtype}, got {ret.dtype}"
if output.realized is not None: output.realized._buf = ret._buf
else: output.realized = ret
# **************** for Compiled Buffers ****************
class ASTRunner:
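The new Interpreted.exec_ast above compiles each distinct LazyOp once through self.compiler and memoizes the resulting run function in method_cache. A toy, tinygrad-free illustration of that compile-once pattern (all names here are made up):

```python
from typing import Callable, Dict, List

method_cache: Dict[str, Callable] = {}

def toy_compiler(ast: str) -> Callable:
    print(f"compiling {ast!r}")                   # happens once per distinct ast
    return lambda inputs, var_vals: sum(inputs)   # stand-in for the generated run()

def toy_exec_ast(ast: str, inputs: List[int], var_vals: Dict) -> int:
    if ast not in method_cache: method_cache[ast] = toy_compiler(ast)
    return method_cache[ast](inputs, var_vals)

print(toy_exec_ast("SUM", [1, 2, 3], {}))   # compiles, prints 6
print(toy_exec_ast("SUM", [4, 5, 6], {}))   # cache hit, prints 15
```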
@@ -247,7 +205,7 @@ class ASTRunner:
self.clprg = runtime(self.name, self.lib)
return self
def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, force_wait=False) -> Optional[float]:
from tinygrad.jit import CacheCollector
CacheCollector.add(self, rawbufs, var_vals if var_vals is not None else {})
return self(rawbufs, var_vals, force_wait=force_wait)
@@ -259,6 +217,7 @@ class ASTRunner:
def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
if var_vals is None: var_vals = {}
var_vals = {k:var_vals[k] for k in self.vars} # filter the var_vals
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program
@@ -327,7 +286,8 @@ class Compiled:
def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
# check if we can reuse the output buffer
# if it's aliased, don't use it
# NOTE: this is pretty wrong actually, who knows where else this buffer is used?
# TODO: this is pretty wrong actually, who knows where else this buffer is used?
# TODO: what if an assign is required? this silently is wrong
output.realized = output.output_buffer
if output.realized is not None:
for i,a in enumerate(inputs):
@@ -345,13 +305,5 @@ class Compiled:
# all the rawbuffers
rawbuffers = [output.realized] + [x.realized for x in inputs]
if getenv("ENABLE_METHOD_CACHE", 1):
if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
prg = self.method_cache[ast]
else:
prg = self.get_optimized_program(ast, rawbuffers)
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
prg.exec(rawbuffers, var_vals={k:var_vals[k] for k in prg.vars})
return output.realized
if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
self.method_cache[ast].exec(rawbuffers, var_vals)

@@ -23,7 +23,7 @@ def run_schedule(schedule:List[ScheduleItem], disable_logging=False):
for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
else:
si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.var_vals, **si.out._device_extra_args())
del si.out.op
for v in si.out.views: del v.op
assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
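With this change exec_ast writes output.realized itself instead of returning the buffer, so run_schedule only checks the result. A hedged end-to-end check of that contract, assuming the CPU (numpy) backend:

```python
from tinygrad.tensor import Tensor

# realize() runs the schedule; the device's exec_ast fills in lazydata.realized
# as a side effect, which is what the assert above then verifies
t = (Tensor([1.0, 2.0], device="CPU") + Tensor([3.0, 4.0], device="CPU")).realize()
print(type(t.lazydata.realized))   # the CPU device's RawBuffer subclass
print(t.numpy())                   # [4. 6.]
```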

@@ -0,0 +1,45 @@
from typing import Callable, Optional, Dict, List, Any
import functools, re
from tinygrad.helpers import DEBUG
from tinygrad.ops import LazyOp, TernaryOps, ReduceOps, BinaryOps, BufferOps, Op
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.lib import RawBuffer
def interpret_ast(fxn_for_op:Dict[Op, Callable], from_underlying:Optional[Callable], ast:LazyOp) -> Callable[[List[RawBuffer], Dict[Variable, int]], RawBuffer]:
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
tglob: Dict[str, Any] = {"Variable": Variable}
lines: List[str] = []
@functools.lru_cache(None)
def gstr(x:Any, nm=None) -> str:
if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
# TODO: (Variable - Variable) might create NumNode. can we remove it?
return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
tglob[ret] = x
return ret
@functools.lru_cache(None)
def _interpret_ast(ast:LazyOp) -> str:
if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
if ast.op in BufferOps:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])"
for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
else:
inp = [_interpret_ast(src) for src in ast.src]
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join(inp + ([gstr(ast.arg)] if ast.arg else []))})"
ret = f"a{len(lines)}"
lines.append(f" {ret} = {tmp}")
return ret
ret = _interpret_ast(ast)
src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(from_underlying, 'from_underlying')}({ret})" if from_underlying is not None else f" return {ret}"])
if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
return tglob['run']
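interpret_ast turns a LazyOp tree into straight-line Python source, one temporary per node, execs it, and hands back the resulting run(inputs, var_vals) function. A toy, self-contained version of the same source-generation technique (the ops and tree format are invented for illustration; only the structure mirrors the real function):

```python
fxn_for_op = {"ADD": lambda a, b: a + b, "MUL": lambda a, b: a * b}

def compile_tree(ast):
    tglob, lines = dict(fxn_for_op), []
    def walk(node):
        if isinstance(node, int): return f"inputs[{node}]"   # leaf: index into inputs
        srcs = [walk(s) for s in node[1:]]                    # children first
        ret = f"a{len(lines)}"
        lines.append(f"  {ret} = {node[0]}({', '.join(srcs)})")
        return ret
    out = walk(ast)
    src = "\n".join(["def run(inputs):"] + lines + [f"  return {out}"])
    exec(compile(src, "<ast>", "exec"), tglob)  # pylint: disable=exec-used
    return tglob["run"]

run = compile_tree(("ADD", ("MUL", 0, 1), 2))   # inputs[0]*inputs[1] + inputs[2]
print(run([3, 4, 5]))                           # 17
```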

@@ -1,9 +1,10 @@
import numpy as np
import operator
import operator, functools
from typing import Callable, Dict, Tuple, Optional
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.interpreted import interpret_ast
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
@@ -51,4 +52,4 @@ class RawNumpyBuffer(RawBuffer):
@classmethod
def fromCPU(cls, x): return cls(x.size, dtypes.from_np(x.dtype), x)
def toCPU(self): return self._buf
CPUBuffer = Interpreted(RawNumpyBuffer, numpy_fxn_for_op, from_underlying=RawNumpyBuffer.fromCPU)
CPUBuffer = Interpreted(RawNumpyBuffer, functools.partial(interpret_ast, numpy_fxn_for_op, RawNumpyBuffer.fromCPU))
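functools.partial pre-binds fxn_for_op and from_underlying, leaving a callable that only needs the LazyOp, i.e. the Callable[[LazyOp], Callable] compiler signature the new Interpreted.__init__ expects. A standalone sketch of the binding order (make_runner is a stand-in, not the real interpret_ast):

```python
import functools

def make_runner(fxn_for_op, from_underlying, ast):        # same parameter order as interpret_ast
    return lambda inputs, var_vals: (ast, inputs, var_vals)

compiler = functools.partial(make_runner, {"ADD": sum}, None)   # what each backend builds
run = compiler("some_ast")                                      # Interpreted supplies only the ast
print(run([1, 2], {}))                                          # ('some_ast', [1, 2], {})
```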

@@ -1,8 +1,9 @@
import os, mmap
import os, mmap, functools
from typing import Optional
from typing import Callable, Dict, Tuple
from tinygrad.helpers import prod, DType
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.runtime.interpreted import interpret_ast
from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps
class RawDiskBuffer(RawBufferMapped):
@@ -38,4 +39,4 @@ class RawDiskBuffer(RawBufferMapped):
self._buf[0].readinto(buf)
disk_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x: x, UnaryOps.CAST: RawDiskBuffer.cast, MovementOps.AS_STRIDED: RawDiskBuffer.as_strided }
DiskBuffer = Interpreted(RawDiskBuffer, disk_fxn_for_op, from_underlying=lambda x:x)
DiskBuffer = Interpreted(RawDiskBuffer, functools.partial(interpret_ast, disk_fxn_for_op, None))

@@ -1,10 +1,11 @@
import os, mmap
import os, mmap, functools
try: import _posixshmem
except Exception: pass
from typing import Callable, Dict
from tinygrad.helpers import DType, OSX
from tinygrad.runtime.lib import RawBufferMapped
from tinygrad.ops import Interpreted, Op, UnaryOps, MovementOps, BufferOps
from tinygrad.runtime.interpreted import interpret_ast
class RawShmBuffer(RawBufferMapped):
def __init__(self, size, dtype:DType, device:str):
@@ -25,4 +26,4 @@ class RawShmBuffer(RawBufferMapped):
# TODO: is this wrong?
shm_fxn_for_op: Dict[Op, Callable] = { BufferOps.MEM: lambda x: x, UnaryOps.NOOP: lambda x:x, MovementOps.RESHAPE: lambda x,_:x, MovementOps.AS_STRIDED: lambda x,_:x }
ShmBuffer = Interpreted(RawShmBuffer, shm_fxn_for_op, from_underlying=lambda x:x)
ShmBuffer = Interpreted(RawShmBuffer, functools.partial(interpret_ast, shm_fxn_for_op, None))

@@ -1,10 +1,11 @@
import torch
import torch, functools
import numpy as np
from typing import Dict, Callable, Optional
from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, Op, Interpreted
from tinygrad.helpers import getenv, dtypes, prod, DType
from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
from tinygrad.runtime.lib import RawBuffer
from tinygrad.runtime.interpreted import interpret_ast
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
type_map = {torch.float64: dtypes.float64, torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool, torch.int16: dtypes.int16}
@@ -48,4 +49,4 @@ class RawTorchBuffer(RawBuffer):
buf = torch.from_numpy(x if all(s>=0 for s in x.strides) else x.copy()).requires_grad_(False).to(device)
return cls(prod(x.shape), type_map[buf.dtype], buf)
def toCPU(self): return self._buf.cpu().numpy()
TorchBuffer = Interpreted(RawTorchBuffer, torch_fxn_for_op, from_underlying=lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x))
TorchBuffer = Interpreted(RawTorchBuffer, functools.partial(interpret_ast, torch_fxn_for_op, lambda x: RawTorchBuffer(prod(x.shape), type_map[x.dtype], x)))
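Here from_underlying becomes the lambda passed to interpret_ast, wrapping the torch tensor produced by the generated run() back into a RawTorchBuffer. A hedged usage sketch (requires torch to be installed; device selection follows the `device` logic above):

```python
from tinygrad.tensor import Tensor

t = (Tensor([1.0, 2.0], device="TORCH") * 3).realize()
print(type(t.lazydata.realized))    # RawTorchBuffer, built by the lambda above
print(t.lazydata.realized._buf)     # the underlying torch tensor
```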