define var can be removed from vars to keep (#3549)

* define var can be removed

* sint

* oops, didn't store
Author: George Hotz
Date: 2024-02-29 17:44:19 -08:00
Committed by: GitHub
Parent commit: 2c19ab6561
Commit: bd9c2ced07

5 changed files with 29 additions and 26 deletions
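
In short: CompiledASTRunner no longer takes the LazyOp ast and derives its variables and cost estimates from it; callers now pass the variables (collected from DEFINE_VAR uops), op_estimate, and mem_estimate explicitly. A minimal sketch of that collection step, assuming the tinygrad code as of this commit (gather_variables is a hypothetical helper name, not something in the repo):

from tinygrad.codegen.uops import UOps

def gather_variables(uops):
  # hypothetical helper: the symbolic Variables now travel via DEFINE_VAR uop args instead of ast.vars()
  return [u.arg for u in uops if u.uop is UOps.DEFINE_VAR]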

@@ -21,7 +21,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}"""
-CompiledASTRunner(None, "atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])
+CompiledASTRunner("atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])
def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)

@@ -13,7 +13,7 @@ from test.test_dtype import is_dtype_supported
def _uops_to_prg(uops):
src = Device[Device.DEFAULT].compiler.render("test", uops)
has_local = Device[Device.DEFAULT].compiler.linearizer_opts.has_local
-return CompiledASTRunner(None, "test", src, Device[Device.DEFAULT], [1] if has_local else None, [1] if has_local else None)
+return CompiledASTRunner("test", src, Device[Device.DEFAULT], [1] if has_local else None, [1] if has_local else None)
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg))

@@ -33,9 +33,8 @@ def get_recursive_children(uops:List[UOp], x:UOp) -> Set[UOp]:
deps.add(u)
return deps
-UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_VAR}
+UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER}
def remove_childless_uops(uops:List[UOp]) -> List[UOp]:
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
while 1:
has_child: Set[UOp] = set()
for ru in uops:
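
Note: with UOps.DEFINE_VAR out of UOPS_W_SIDE_EFFECTS, the childless-uop pass may now drop a DEFINE_VAR that nothing reads; since the runner's variable list is built from the surviving uops, an unused variable simply stops being a kernel argument. A simplified, self-contained sketch of that keep-if-consumed-or-side-effecting sweep (toy classes, not tinygrad's real UOp or remove_childless_uops):

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ToyUOp:                        # simplified stand-in for tinygrad's UOp
  op: str
  arg: str = ""
  src: Tuple["ToyUOp", ...] = ()

SIDE_EFFECTS = {"STORE", "BARRIER"}  # DEFINE_VAR no longer counts as a side effect

def remove_childless(uops):
  # keep a uop only if another uop consumes it or it has a side effect; repeat to a fixed point
  while True:
    consumed = {s for u in uops for s in u.src}
    kept = [u for u in uops if u in consumed or u.op in SIDE_EFFECTS]
    if len(kept) == len(uops): return kept
    uops = kept

unused = ToyUOp("DEFINE_VAR", "i")            # nothing reads it -> now removable
used   = ToyUOp("DEFINE_VAR", "j")
store  = ToyUOp("STORE", src=(used,))         # keeps `used` alive
print([(u.op, u.arg) for u in remove_childless([unused, used, store])])  # [('DEFINE_VAR', 'j'), ('STORE', '')]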

@@ -3,6 +3,7 @@ from collections import defaultdict
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar
import importlib, inspect, functools, pathlib, time, ctypes
from tinygrad.dtype import DType, ImageDType
+from tinygrad.codegen.uops import UOps
from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -40,7 +41,9 @@ Device = _Device()
# **************** base Runner + helpers ****************
class JITRunner:
-def __init__(self): self.op_estimate, self.mem_estimate = 0, 0
+def __init__(self):
+self.op_estimate:sint = 0
+self.mem_estimate:sint = 0
def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
var_vals = var_vals if var_vals is not None else {}
from tinygrad.features.jit import CacheCollector
@@ -185,7 +188,8 @@ class Compiler:
return lib
class CompiledASTRunner(JITRunner):
-def __init__(self, ast:Optional[LazyOp], name:str, prg:str, device:Compiled, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, precompiled:Optional[bytes]=None): # noqa: E501
+def __init__(self, name:str, prg:str, device:Compiled, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
+variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None):
super().__init__()
if DEBUG >= 4: print(prg)
if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
@@ -195,12 +199,8 @@ class CompiledASTRunner(JITRunner):
assert self.device.compiler is not None, "compiler is required to make an AST kernel"
lib:bytes = precompiled if precompiled is not None else self.device.compiler.compile_cached(prg)
self.lib, self.clprg = lib, self.device.runtime(self.name, lib)
-self.vars: List[Variable] = []
-if ast:
-info = get_lazyop_info(ast)
-self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
-self.vars = ast.vars()
-assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
+self.vars: List[Variable] = [] if variables is None else variables
+self.op_estimate, self.mem_estimate = op_estimate, mem_estimate
def launch_dims(self, var_vals):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
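
Note: op_estimate and mem_estimate are now typed as sint (a plain int or a symbolic Node), matching how the launch dims above are resolved with sym_infer at run time. A small sketch of that resolution, assuming this commit's tinygrad (the variable name "i" is illustrative):

from tinygrad.shape.symbolic import Variable, sym_infer

i = Variable("i", 1, 64)          # an unbound symbolic variable
global_size = [i * 4, 1, 1]       # a symbolic launch dimension
print([sym_infer(sz, {i: 10}) for sz in global_size])   # [40, 1, 1]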
@@ -235,13 +235,14 @@ class Compiled:
def to_program(self, k:Linearizer) -> CompiledASTRunner:
assert self.compiler is not None, "compiler is required to run AST"
k.linearize()
-ret = CompiledASTRunner(k.ast, k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size)
+info = get_lazyop_info(k.ast)
from tinygrad.codegen.uops import uops_flops_mem
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
ops, mem = uops_flops_mem(k.uops)
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
-ret.op_estimate = min(ret.op_estimate, ops * run_count)
-ret.mem_estimate = min(ret.mem_estimate, mem * run_count)
+ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
+[x.arg for x in k.uops if x.uop is UOps.DEFINE_VAR],
+min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
return ret
def get_linearizer(self, ast:LazyOp) -> Linearizer:
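
Note: the estimates are now computed once in to_program: get_lazyop_info gives the ast-level totals, uops_flops_mem counts what the uop program does per launch, run_count scales that by the launch grid, and min() keeps indexing FLOPS from inflating the number. A toy, runnable illustration of that clamp (all figures made up):

from math import prod

global_size, local_size = [256, 1, 1], [32, 1, 1]
run_count = prod(global_size + local_size)          # 256*32 = 8192 executions of the uop program

ast_flops      = 1_000_000    # stand-in for get_lazyop_info(...).flops
flops_per_item = 150          # stand-in for the uops_flops_mem(...) ops count
print(min(ast_flops, flops_per_item * run_count))   # min(1_000_000, 1_228_800) -> 1_000_000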

@@ -1,13 +1,14 @@
from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, functools, random, math, time, multiprocessing, traceback, signal
from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner, Compiler
-from tinygrad.ops import MemBuffer, LazyOp
+from tinygrad.ops import MemBuffer
+from tinygrad.codegen.uops import UOps
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType
from tinygrad.codegen.linearizer import Linearizer
from collections import defaultdict
from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import sym_infer
+from tinygrad.shape.symbolic import sym_infer, Variable
from tinygrad.codegen.kernel import Opt, OptOps
actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
@@ -30,11 +31,12 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
break
return test_global_size, factor
-def _time_program(ast:LazyOp, rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"): # noqa: E501
+def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs,
+early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
factor = 1
if global_size is not None and max_global_size is not None:
global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
-try: car = CompiledASTRunner(ast, name, "", rdev, global_size, local_size, precompiled=lib)
+try: car = CompiledASTRunner(name, "", rdev, global_size, local_size, variables=variables, precompiled=lib)
except AssertionError: return [math.inf] * cnt
tms = []
for _ in range(cnt):
@@ -44,10 +46,11 @@ def _time_program(ast:LazyOp, rdev:Compiled, lib:bytes, global_size, local_size,
if early_stop is not None and early_stop < tms[-1]: break
return tms
-def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]]]:
+def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]],
+List[Variable]]:
lin.linearize()
src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops) # NOTE: these all have the same name for deduping
-return compiler.compile(src), lin.global_size, lin.local_size
+return compiler.compile(src), lin.global_size, lin.local_size, [x.arg for x in lin.uops if x.uop is UOps.DEFINE_VAR]
def _try_compile_linearized_w_idx(x, compiler:Compiler):
try: return (x[0], _compile_linearizer(compiler, x[1], "test"))
@@ -116,10 +119,10 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
_compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=Device[lin.opts.device].compiler)
for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
if proc is None: continue
-lib, global_size, local_size = proc
+lib, global_size, local_size, vars = proc
if lib in seen_libs: continue
seen_libs.add(lib)
-tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
+tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
timed_lins.append((acted_lins[i], min(tms)))
if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="") # noqa: E501
@@ -157,8 +160,8 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
assert isinstance(dev, Compiled) and dev.compiler is not None
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
-lib, global_size, local_size = _compile_linearizer(dev.compiler, lin)
-tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
+lib, global_size, local_size, vars = _compile_linearizer(dev.compiler, lin)
+tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) # noqa: E501
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
return min(tms)
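
Note: the search and timing path now threads the DEFINE_VAR-derived variables from _compile_linearizer into the precompiled CompiledASTRunner, while var_vals stays the midpoint of each variable's range (the comprehension above). A minimal sketch of that midpoint choice, assuming this commit's tinygrad ("start_pos" is only an example name):

from tinygrad.shape.symbolic import Variable

variables = [Variable("start_pos", 1, 1024)]                # e.g. what the DEFINE_VAR uops carry
var_vals = {v: (v.max + v.min) // 2 for v in variables}     # time each kernel at the middle of the range
print({v.expr: n for v, n in var_vals.items()})             # {'start_pos': 512}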