mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Fix VALIDHACKS for Images and make it default (#1832)
* valid hacks * valid hacks * valid hacks * new method * new method * handtune * is gate load breaking? * lint ruff less junk new approach? maybe this? * Make it more clear * Make it more clear * Will deal with the linter later * hack for linter * subs the idx but dont touch the valid * Updated the mod rules * lint hack * I believe bug fix lets see * Mod Node left * revert * Maybe this wont break? * revert * implemented "handtuned garbage" * revert and use VALIDHACKS * Lets see the CI * still broken? * currently its jungle * maybe this jungle ? * This works for everything somehow * Added test for symbolic * lint * final touch * This still works * lint * midway clean * less garbage * lint * final form * Slow but working way * lint and other stuff * lint * mypy * Make sure CI test Openpilot valid checks * test if CI break * Convert back * refactor * refactor * Managed to reduce openpilot time from 30 secs to 5 secs * Refactor * Substitute a node with variable * flake8 * Comment and refactor * More comprehensive mod * refactor * bug fix * More shave off * remove not sure part
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -147,7 +147,7 @@ jobs:
|
||||
- if: ${{ matrix.task == 'openpilot' }}
|
||||
name: Test openpilot model compile and size
|
||||
run: |
|
||||
DEBUG=2 ALLOWED_KERNEL_COUNT=209 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
|
||||
DEBUG=2 ALLOWED_KERNEL_COUNT=209 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
|
||||
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
|
||||
- if: ${{ matrix.task == 'openpilot' }}
|
||||
name: Test openpilot model correctness (float32)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
import os, time, io, pathlib, sys, traceback
|
||||
import os, time, io, pathlib, sys, traceback, re
|
||||
sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
|
||||
|
||||
if os.getenv("OPT", None) is None:
|
||||
@@ -72,6 +72,11 @@ def compile(dat, output_fn):
|
||||
# pass these to thneed
|
||||
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
|
||||
setattr(prg.clprg, 'prg', prg.prg)
|
||||
|
||||
if getenv("VALIDTEST") == 1:
|
||||
src = re.search(r"=.*\?.*?read_image", prg.prg)
|
||||
if src is not None: raise Exception("Openpilot has valid checks!")
|
||||
|
||||
global_size = prg.global_size + [1]*(3-len(prg.global_size))
|
||||
local_size = prg.local_size + [1]*(3-len(prg.local_size))
|
||||
cl_cache.append((prg.clprg, [[g*l for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in args]]))
|
||||
|
||||
@@ -130,6 +130,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
|
||||
|
||||
def test_mod_to_sub(self):
|
||||
# This is mod reduction
|
||||
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, (Variable("a",1,2)-1).render())
|
||||
|
||||
def test_sum_div_const(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable.num(3)]) // 4, 0, 7, "a")
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import itertools, math, functools
|
||||
from collections import defaultdict
|
||||
from enum import Enum, auto
|
||||
|
||||
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, partition, prod, PtrDType, all_same
|
||||
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
|
||||
from tinygrad.ops import LazyOp, UnaryOps
|
||||
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
|
||||
from tinygrad.runtime.lib import RawConst
|
||||
@@ -21,25 +21,86 @@ class UOps(Enum):
|
||||
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto() # noqa: E702
|
||||
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
|
||||
|
||||
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
|
||||
idy = (idxy//(4*base_shape[1]))
|
||||
if validhacks and valid.min == 0:
|
||||
idx = (idxy//4) + (idy*-base_shape[1])
|
||||
# find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
|
||||
if isinstance(idx, SumNode):
|
||||
unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
|
||||
assert len(unfactored) <= 1
|
||||
idx = Variable.sum(idx_nodes)
|
||||
unfactored = (Variable.sum(unfactored) // base_shape[1])
|
||||
idy += unfactored
|
||||
# ugh really...handtuned garbage
|
||||
if idx.min >= (base_shape[1]*3)//4:
|
||||
idx -= base_shape[1]
|
||||
idy += 1
|
||||
else:
|
||||
idx = (idxy//4)%base_shape[1]
|
||||
if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
|
||||
return idx, idy
|
||||
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
|
||||
# This part is substituting variables by just looking at single var LtNodes in valid
|
||||
# Basically if var[0-5] < 3 -> var[0-2]
|
||||
if valid.min == 0:
|
||||
nodes: List = valid.nodes if isinstance(valid, AndNode) else [valid]
|
||||
var_dict = {var:[var.min, var.max] for var in valid.vars()}
|
||||
|
||||
for nd in nodes:
|
||||
var_range = var_dict[nd.vars()[0]]
|
||||
if isinstance(nd.a, MulNode):
|
||||
if nd.a.b < 0:
|
||||
var_range[0] = (nd.b // nd.a.b) + 1
|
||||
elif nd.a.b > 0:
|
||||
var_range[1] = (nd.b // nd.a.b) - 1 if nd.b % nd.a.b == 0 else nd.b // nd.a.b
|
||||
elif isinstance(nd.a, Variable):
|
||||
var_range[1] = nd.b - 1
|
||||
# We do not allow NumNode because it is constant
|
||||
# TODO: Remove mx != mn
|
||||
sub_dict: dict[Union[Variable, NumNode], Node] = {v:Variable(v.expr, mn, mx) for v, (mn, mx) in var_dict.items() if mx != mn}
|
||||
valid, idxy = valid.substitute(sub_dict), idxy.substitute(sub_dict)
|
||||
|
||||
idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
|
||||
idx_vars, idy_vars, val_vars = set(idx.vars()), set(idy.vars()), set(valid.vars())
|
||||
|
||||
# Simplify ModNode if possibe # test_padded_conv_transpose2d, Needs much more thinking
|
||||
if valid.min == 0 and isinstance(idx, ModNode) and isinstance(idx.a, SumNode):
|
||||
nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
|
||||
same_dict: Dict[Node, List[Tuple[int, Node]]] = {}
|
||||
idx_nodes = idx.a.flat_components
|
||||
|
||||
for node in nodes:
|
||||
if not isinstance(node, LtNode) or not isinstance(node.a, SumNode): continue
|
||||
|
||||
nd_flat, nd_vars = node.a.flat_components, node.vars()
|
||||
|
||||
same = [x for x in idx_nodes if (x.a if isinstance(x, MulNode) else x) in nd_vars]
|
||||
|
||||
if len(same) != len(nd_vars): continue
|
||||
|
||||
first_b, second_b = nd_flat[0].b if isinstance(nd_flat[0], MulNode) else 1, same[0].b if isinstance(same[0], MulNode) else 1
|
||||
k, same_sum = second_b//first_b, Variable.sum(same)
|
||||
|
||||
if k*(node.a) == same_sum: same_dict[same_sum] = same_dict.get(same_sum, []) + [(k, node)]
|
||||
|
||||
for key in same_dict.keys():
|
||||
same, mnn, mxn = key.flat_components, key.min, key.max # type: ignore # Same is sumnode because node.a is SumNode
|
||||
for k, node in same_dict[key]: # TODO: This part may need more thinking
|
||||
if k < 0: mnn = (-k)*max((-node.b) + 1, min([-lal.b if isinstance(lal, MulNode) else 1 for lal in same]))
|
||||
else: mxn = (node.b - 1)*k
|
||||
|
||||
fake_var = Variable("valid_fake", mnn, mxn)
|
||||
total = (Variable.sum([x for x in idx_nodes if x not in same]) + fake_var) % idx.b
|
||||
idx = total.substitute({fake_var: key})
|
||||
# TODO: If idx has no ModNode we may can remove the valid node, but removing it needs careful thinking
|
||||
|
||||
# Simplify SumNodes
|
||||
# This part just removes valid nodes if node is exactly same as idx or idy
|
||||
# idx = 3*a + b (+ 5), valid = 3*a + b < 10 # Valid will be removed as idx will go out of bounds
|
||||
# Check for var intersection, removing valid can affect other index
|
||||
if valid.min == 0 and not idx_vars.intersection(idy_vars):
|
||||
nds = valid.nodes if isinstance(valid, AndNode) else [valid]
|
||||
flats = [id.flat_components for id in (idx, idy) if isinstance(id, SumNode)]
|
||||
sym_sums = [Variable.sum([i for i in flat if not isinstance(i, NumNode)]) for flat in flats]
|
||||
ones = [node for sym_sum in sym_sums for node in nds if (node.a == sym_sum) or (-(node.a) == sym_sum)] # type: ignore # AndNode always consists of LtNode
|
||||
valid = Variable.ands([i for i in nds if i not in ones])
|
||||
|
||||
# This is the slow part
|
||||
# This part is for brute forcing all possible values of idx, idy and valid
|
||||
# If valid is both 0 and 1 for the same (idx, idy) we can not delete the valid
|
||||
if valid.min == 0 and not isinstance(idx, ModNode):
|
||||
variables = tuple(val_vars | idy_vars | idx_vars)
|
||||
val_infer, idx_infer, idy_infer = valid.expand(variables), idx.expand(variables), idy.expand(variables)
|
||||
val_dict: Dict[int, Set[Tuple[int,int]]] = {0:set(), 1:set()}
|
||||
|
||||
for v, x, y in zip(val_infer, idx_infer, idy_infer): val_dict[v.min].add((x.min, y.min))
|
||||
|
||||
if not val_dict[1].intersection(val_dict[0]): valid = NumNode(1)
|
||||
|
||||
if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
|
||||
return (idx, idy), valid
|
||||
|
||||
class UOp(NamedTuple):
|
||||
uop: UOps
|
||||
@@ -129,10 +190,11 @@ class Linearizer(OptimizedKernel):
|
||||
buf_uop = self.buf_uops[i]
|
||||
assert buf_uop is not None, f"buffer {i} wasn't UOped"
|
||||
if isinstance(self.bufs[i].dtype, ImageDType):
|
||||
idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
|
||||
idx, valid = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
|
||||
if valid.min == 0:
|
||||
valid_rendered = valid.render(self.render_ops, self)
|
||||
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)))
|
||||
@@ -168,7 +230,7 @@ class Linearizer(OptimizedKernel):
|
||||
for idx, var in store_offset.items():
|
||||
idx, valid = self.sts[i].expr_idxs(idx)
|
||||
if isinstance(self.bufs[i].dtype, ImageDType):
|
||||
idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
|
||||
idx, valid = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
|
||||
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
|
||||
else:
|
||||
rendered_idx = idx.render(self.render_ops, self)
|
||||
|
||||
@@ -97,6 +97,7 @@ class Node:
|
||||
assert b > 0
|
||||
if b == 1: return NumNode(0)
|
||||
if self.min >= 0 and self.max < b: return self
|
||||
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
|
||||
if self.min < 0: return (self - ((self.min//b)*b)) % b
|
||||
return create_node(ModNode(self, b))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user