mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
define var can be removed from vars to keep (#3549)
* define var can be removed
* sint
* oops, didn't store
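The shape of the change: CompiledASTRunner no longer takes the LazyOp ast and derives its vars and op/mem estimates internally; callers now compute the DEFINE_VAR variables and the estimates and pass them explicitly. A minimal self-contained sketch of the new constructor style (toy stand-in classes, purely illustrative; not tinygrad's real API):

from typing import List, Optional

class Variable:  # toy stand-in for tinygrad's symbolic Variable
    def __init__(self, name: str): self.name = name

class CompiledASTRunnerSketch:
    # new-style constructor: metadata is passed in, not derived from an AST
    def __init__(self, name: str, prg: str, variables: Optional[List[Variable]] = None,
                 op_estimate: int = 0, mem_estimate: int = 0):
        self.vars: List[Variable] = [] if variables is None else variables
        self.op_estimate, self.mem_estimate = op_estimate, mem_estimate

# before: CompiledASTRunner(ast, name, src, ...) pulled vars/estimates out of ast
# after: the caller computes them once and hands them over
runner = CompiledASTRunnerSketch("atan2_gpu", "...kernel source...",
                                 variables=[Variable("n")], op_estimate=100, mem_estimate=12)
print([v.name for v in runner.vars], runner.op_estimate, runner.mem_estimate)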
test/test_custom_function.py
@@ -21,7 +21,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
   int idx = get_global_id(0);
   c[idx] = atan2(a[idx], b[idx]);
 }"""
-  CompiledASTRunner(None, "atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])
+  CompiledASTRunner("atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])

 def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)

test/test_uops.py
@@ -13,7 +13,7 @@ from test.test_dtype import is_dtype_supported
 def _uops_to_prg(uops):
   src = Device[Device.DEFAULT].compiler.render("test", uops)
   has_local = Device[Device.DEFAULT].compiler.linearizer_opts.has_local
-  return CompiledASTRunner(None, "test", src, Device[Device.DEFAULT], [1] if has_local else None, [1] if has_local else None)
+  return CompiledASTRunner("test", src, Device[Device.DEFAULT], [1] if has_local else None, [1] if has_local else None)

 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(vin), arg))
tinygrad/codegen/uops.py
@@ -33,9 +33,8 @@ def get_recursive_children(uops:List[UOp], x:UOp) -> Set[UOp]:
       deps.add(u)
   return deps

-UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_VAR}
+UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER}
 def remove_childless_uops(uops:List[UOp]) -> List[UOp]:
   # NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
   while 1:
     has_child: Set[UOp] = set()
     for ru in uops:
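For context, here is a self-contained sketch of the childless-uop elimination remove_childless_uops performs (toy tuples, not tinygrad's UOp class): with DEFINE_VAR out of UOPS_W_SIDE_EFFECTS, an unused variable definition is now garbage-collected like any other uop nothing consumes.

from typing import List, Set, Tuple

UOp = Tuple[str, Tuple[int, ...]]  # (op name, indices of input uops)
SIDE_EFFECTS = {"STORE", "BARRIER"}  # DEFINE_VAR no longer counts as a side effect

def remove_childless(uops: List[UOp]) -> List[UOp]:
    # keep a uop only if something consumes it or it has a side effect,
    # re-running until a fixed point since removals can orphan more uops
    while True:
        has_child: Set[int] = set()
        for _, vin in uops: has_child.update(vin)
        keep = [i for i, (op, _) in enumerate(uops) if op in SIDE_EFFECTS or i in has_child]
        if len(keep) == len(uops): return uops
        remap = {old: new for new, old in enumerate(keep)}
        uops = [(uops[i][0], tuple(remap[v] for v in uops[i][1])) for i in keep]

# the unused DEFINE_VAR at index 0 is dropped; CONST and STORE survive
print(remove_childless([("DEFINE_VAR", ()), ("CONST", ()), ("STORE", (1,))]))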
tinygrad/device.py
@@ -3,6 +3,7 @@ from collections import defaultdict
 from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar
 import importlib, inspect, functools, pathlib, time, ctypes
 from tinygrad.dtype import DType, ImageDType
+from tinygrad.codegen.uops import UOps
 from tinygrad.helpers import ansilen, DEBUG, getenv, colored, BEAM, NOOPT, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put
 from tinygrad.helpers import prod
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -40,7 +41,9 @@ Device = _Device()
 # **************** base Runner + helpers ****************

 class JITRunner:
-  def __init__(self): self.op_estimate, self.mem_estimate = 0, 0
+  def __init__(self):
+    self.op_estimate:sint = 0
+    self.mem_estimate:sint = 0
   def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
     var_vals = var_vals if var_vals is not None else {}
     from tinygrad.features.jit import CacheCollector
@@ -185,7 +188,8 @@ class Compiler:
     return lib

 class CompiledASTRunner(JITRunner):
-  def __init__(self, ast:Optional[LazyOp], name:str, prg:str, device:Compiled, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, precompiled:Optional[bytes]=None):  # noqa: E501
+  def __init__(self, name:str, prg:str, device:Compiled, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
+               variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None):
     super().__init__()
     if DEBUG >= 4: print(prg)
     if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
@@ -195,12 +199,8 @@ class CompiledASTRunner(JITRunner):
     assert self.device.compiler is not None, "compiler is required to make an AST kernel"
     lib:bytes = precompiled if precompiled is not None else self.device.compiler.compile_cached(prg)
     self.lib, self.clprg = lib, self.device.runtime(self.name, lib)
-    self.vars: List[Variable] = []
-    if ast:
-      info = get_lazyop_info(ast)
-      self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
-      self.vars = ast.vars()
-      assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}"
+    self.vars: List[Variable] = [] if variables is None else variables
+    self.op_estimate, self.mem_estimate = op_estimate, mem_estimate

   def launch_dims(self, var_vals):
     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
@@ -235,13 +235,14 @@ class Compiled:
   def to_program(self, k:Linearizer) -> CompiledASTRunner:
     assert self.compiler is not None, "compiler is required to run AST"
     k.linearize()
-    ret = CompiledASTRunner(k.ast, k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size)
+    info = get_lazyop_info(k.ast)
     from tinygrad.codegen.uops import uops_flops_mem
-    run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
     ops, mem = uops_flops_mem(k.uops)
+    run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
     # NOTE: we use min here to ignore the indexing FLOPS
-    ret.op_estimate = min(ret.op_estimate, ops * run_count)
-    ret.mem_estimate = min(ret.mem_estimate, mem * run_count)
+    ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
+                            [x.arg for x in k.uops if x.uop is UOps.DEFINE_VAR],
+                            min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
     return ret

   def get_linearizer(self, ast:LazyOp) -> Linearizer:
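On the min in to_program: the uop-level count (ops * run_count) includes indexing arithmetic that the AST-level estimate (info.flops) doesn't, while the AST estimate can count work codegen eliminated, so the elementwise min keeps whichever bound is tighter. A toy arithmetic check with made-up numbers:

# hypothetical numbers, purely illustrative
info_flops = 1_000_000           # AST-level estimate: useful math only
ops_per_run = 12                 # FLOPS counted in the linearized uops, incl. indexing
run_count = 256 * 256            # prod(global_size + local_size)
print(min(info_flops, ops_per_run * run_count))  # 786432 -> the uop count is tighter here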
tinygrad/features/search.py
@@ -1,13 +1,14 @@
 from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from tinygrad.device import Device, Compiled, Buffer, CompiledASTRunner, Compiler
-from tinygrad.ops import MemBuffer, LazyOp
+from tinygrad.ops import MemBuffer
+from tinygrad.codegen.uops import UOps
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.dtype import ImageDType
 from tinygrad.codegen.linearizer import Linearizer
 from collections import defaultdict
 from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import sym_infer
+from tinygrad.shape.symbolic import sym_infer, Variable

 from tinygrad.codegen.kernel import Opt, OptOps
 actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)]
@@ -30,11 +31,12 @@ def _get_test_global_size(global_size, max_global_size, var_vals):
       break
   return test_global_size, factor

-def _time_program(ast:LazyOp, rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs, early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):  # noqa: E501
+def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_size, local_size, var_vals, rawbufs,
+                  early_stop=None, max_global_size=65536, clear_l2=False, cnt=3, name="test"):
   factor = 1
   if global_size is not None and max_global_size is not None:
     global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
-  try: car = CompiledASTRunner(ast, name, "", rdev, global_size, local_size, precompiled=lib)
+  try: car = CompiledASTRunner(name, "", rdev, global_size, local_size, variables=variables, precompiled=lib)
   except AssertionError: return [math.inf] * cnt
   tms = []
   for _ in range(cnt):
@@ -44,10 +46,11 @@ def _time_program(ast:LazyOp, rdev:Compiled, lib:bytes, global_size, local_size,
     if early_stop is not None and early_stop < tms[-1]: break
   return tms

-def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]]]:
+def _compile_linearizer(compiler:Compiler, lin:Linearizer, name:Optional[str]=None) -> Tuple[bytes, Optional[List[int]], Optional[List[int]],
+                                                                                              List[Variable]]:
   lin.linearize()
   src = compiler.render(name if name is not None else to_function_name(lin.name), lin.uops)  # NOTE: these all have the same name for deduping
-  return compiler.compile(src), lin.global_size, lin.local_size
+  return compiler.compile(src), lin.global_size, lin.local_size, [x.arg for x in lin.uops if x.uop is UOps.DEFINE_VAR]

 def _try_compile_linearized_w_idx(x, compiler:Compiler):
   try: return (x[0], _compile_linearizer(compiler, x[1], "test"))
@@ -116,10 +119,10 @@ def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linea
     _compile_fn = functools.partial(_try_compile_linearized_w_idx, compiler=Device[lin.opts.device].compiler)
     for i,proc in (map(_compile_fn, enumerate(acted_lins)) if beam_pool is None else beam_pool.imap_unordered(_compile_fn, enumerate(acted_lins))):
       if proc is None: continue
-      lib, global_size, local_size = proc
+      lib, global_size, local_size, vars = proc
       if lib in seen_libs: continue
       seen_libs.add(lib)
-      tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
+      tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0)
       timed_lins.append((acted_lins[i], min(tms)))
       if DEBUG >= 2: print(f"\r{time.perf_counter() - st:7.2f}s: {timed_lins[-1][1]*1e6:12.2f} us {len(timed_lins):4d}/{len(acted_lins):4d} {timed_lins[-1][0].colored_shape()}\033[K", end="")  # noqa: E501
@@ -157,8 +160,8 @@ def time_linearizer(lin:Linearizer, rawbufs:List[Buffer], allow_test_size=True,
   assert isinstance(dev, Compiled) and dev.compiler is not None

   var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
-  lib, global_size, local_size = _compile_linearizer(dev.compiler, lin)
-  tms = _time_program(lin.ast, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))  # noqa: E501
+  lib, global_size, local_size, vars = _compile_linearizer(dev.compiler, lin)
+  tms = _time_program(vars, dev, lib, global_size, local_size, var_vals, rawbufs, max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))  # noqa: E501

   if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
   return min(tms)
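The var_vals binding above picks the midpoint of each symbolic Variable's range, so benchmarking runs with one representative concrete size per variable. A toy version of that binding (stand-in class, not tinygrad's Variable):

class Var:  # toy stand-in carrying the min/max bounds the real Variable has
    def __init__(self, name: str, vmin: int, vmax: int):
        self.name, self.min, self.max = name, vmin, vmax

kernel_vars = [Var("i", 1, 10), Var("j", 0, 512)]
var_vals = {v.name: (v.max + v.min) // 2 for v in kernel_vars}
print(var_vals)  # {'i': 5, 'j': 256}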