Fix VALIDHACKS for Images and make it default (#1832)

* valid hacks

* valid hacks

* valid hacks

* new method

* new method

* handtune

* is gate load breaking?

* lint

ruff

less junk

new approach?

maybe this?

* Make it more clear

* Make it more clear

* Will deal with the linter later

* hack for linter

* substitute the idx but don't touch the valid

* Updated the mod rules

* lint hack

* I believe this is a bug fix, let's see

* Mod Node left

* revert

* Maybe this won't break?

* revert

* implemented "handtuned garbage"

* revert and use VALIDHACKS

* Lets see the CI

* still broken?

* currently it's a jungle

* maybe this jungle ?

* This works for everything somehow

* Added test for symbolic

* lint

* final touch

* This still works

* lint

* midway clean

* less garbage

* lint

* final form

* Slow but working way

* lint and other stuff

* lint

* mypy

* Make sure CI test Openpilot valid checks

* test if CI break

* Convert back

* refactor

* refactor

* Managed to reduce openpilot time from 30 secs to 5 secs

* Refactor

* Substitute a node with variable

* flake8

* Comment and refactor

* More comprehensive mod

* refactor

* bug fix

* More shave off

* remove not sure part
This commit is contained in:
Umut Zengin
2023-09-23 02:34:43 +03:00
committed by GitHub
parent 767bb35903
commit 3987280daf
5 changed files with 96 additions and 24 deletions

View File

@@ -147,7 +147,7 @@ jobs:
- if: ${{ matrix.task == 'openpilot' }}
name: Test openpilot model compile and size
run: |
DEBUG=2 ALLOWED_KERNEL_COUNT=209 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
DEBUG=2 ALLOWED_KERNEL_COUNT=209 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'openpilot' }}
name: Test openpilot model correctness (float32)

View File

@@ -1,5 +1,5 @@
#!/usr/bin/env python3
import os, time, io, pathlib, sys, traceback
import os, time, io, pathlib, sys, traceback, re
sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
if os.getenv("OPT", None) is None:
@@ -72,6 +72,11 @@ def compile(dat, output_fn):
# pass these to thneed
setattr(prg.clprg, 'op_estimate', prg.op_estimate)
setattr(prg.clprg, 'prg', prg.prg)
if getenv("VALIDTEST") == 1:
src = re.search(r"=.*\?.*?read_image", prg.prg)
if src is not None: raise Exception("Openpilot has valid checks!")
global_size = prg.global_size + [1]*(3-len(prg.global_size))
local_size = prg.local_size + [1]*(3-len(prg.local_size))
cl_cache.append((prg.clprg, [[g*l for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in args]]))

View File

@@ -130,6 +130,10 @@ class TestSymbolic(unittest.TestCase):
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
def test_mod_to_sub(self):
  # Mod reduction: with a in [1, 2], (1+a) % 2 is expected to render exactly like a-1.
  a = Variable("a", 1, 2)
  expected = (a - 1).render()
  self.helper_test_variable((1 + a) % 2, 0, 1, expected)
def test_sum_div_const(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable.num(3)]) // 4, 0, 7, "a")

View File

@@ -4,7 +4,7 @@ import itertools, math, functools
from collections import defaultdict
from enum import Enum, auto
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, partition, prod, PtrDType, all_same
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
from tinygrad.ops import LazyOp, UnaryOps
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
from tinygrad.runtime.lib import RawConst
@@ -21,25 +21,86 @@ class UOps(Enum):
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
# Splits the flat index `idxy` into an (idx, idy) pair, using base_shape[1] as the row width
# (in groups of 4 elements), and returns it. `valid` is only inspected, never modified or returned.
# NOTE(review): indentation is not visible in this view; nesting described below is inferred from the statements.
idy = (idxy//(4*base_shape[1]))
# Hack path: only taken when explicitly enabled and `valid` can be 0.
if validhacks and valid.min == 0:
# Written as (idxy//4) - idy*base_shape[1] instead of a modulo.
idx = (idxy//4) + (idy*-base_shape[1])
# find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
if isinstance(idx, SumNode):
# Separate the terms that are still multiplied by -base_shape[1] from the rest of the sum.
unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
assert len(unfactored) <= 1
idx = Variable.sum(idx_nodes)
# The separated term is divided by base_shape[1] and folded into idy instead.
unfactored = (Variable.sum(unfactored) // base_shape[1])
idy += unfactored
# ugh really...handtuned garbage
# If idx starts in the last quarter of a row, shift it back by one row width and bump idy.
if idx.min >= (base_shape[1]*3)//4:
idx -= base_shape[1]
idy += 1
# Default path (presumably pairs with the outer `if validhacks ...` — confirm against the real file).
else:
idx = (idxy//4)%base_shape[1]
if DEBUG >= 5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
return idx, idy
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
# Splits the flat index `idxy` into an (idx, idy) pair using base_shape[1] as the row width
# (in groups of 4 elements) and tries to simplify or drop the `valid` condition along the way.
# Returns ((idx, idy), valid), where `valid` may have been narrowed, had nodes removed, or become NumNode(1).
# NOTE(review): indentation is not visible in this view; the stage boundaries below follow the statements' data flow.
#
# Stage 1: narrow variable ranges.
# This part is substituting variables by just looking at single var LtNodes in valid
# Basically if var[0-5] < 3 -> var[0-2]
if valid.min == 0:
nodes: List = valid.nodes if isinstance(valid, AndNode) else [valid]
# Mutable [min, max] per variable; the loop below tightens these in place.
var_dict = {var:[var.min, var.max] for var in valid.vars()}
for nd in nodes:
# Only the first variable of each node is considered.
var_range = var_dict[nd.vars()[0]]
# The bounds below assume nd reads as `nd.a < nd.b` (an LtNode) — TODO confirm.
if isinstance(nd.a, MulNode):
if nd.a.b < 0:
# var*k < b with k < 0  =>  var >= floor(b/k) + 1
var_range[0] = (nd.b // nd.a.b) + 1
elif nd.a.b > 0:
# var*k < b with k > 0  =>  var <= b/k - 1 when k divides b, else floor(b/k)
var_range[1] = (nd.b // nd.a.b) - 1 if nd.b % nd.a.b == 0 else nd.b // nd.a.b
elif isinstance(nd.a, Variable):
# var < b  =>  var <= b - 1
var_range[1] = nd.b - 1
# We do not allow NumNode because it is constant
# TODO: Remove mx != mn
# Variables whose range collapsed to a single value are skipped (not substituted).
sub_dict: dict[Union[Variable, NumNode], Node] = {v:Variable(v.expr, mn, mx) for v, (mn, mx) in var_dict.items() if mx != mn}
valid, idxy = valid.substitute(sub_dict), idxy.substitute(sub_dict)
# Split the (possibly narrowed) flat index into column and row.
idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
# Variable sets are captured once here; later rewrites of idx/valid do not refresh them.
idx_vars, idy_vars, val_vars = set(idx.vars()), set(idy.vars()), set(valid.vars())
# Stage 2: simplify ModNode if possible # test_padded_conv_transpose2d, Needs much more thinking
if valid.min == 0 and isinstance(idx, ModNode) and isinstance(idx.a, SumNode):
nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
# Maps a sub-sum of idx's components to the (scale k, valid node) pairs that constrain it.
same_dict: Dict[Node, List[Tuple[int, Node]]] = {}
idx_nodes = idx.a.flat_components
for node in nodes:
if not isinstance(node, LtNode) or not isinstance(node.a, SumNode): continue
nd_flat, nd_vars = node.a.flat_components, node.vars()
# Components of idx whose variable also appears in this valid node.
same = [x for x in idx_nodes if (x.a if isinstance(x, MulNode) else x) in nd_vars]
if len(same) != len(nd_vars): continue
# k is the ratio between the first matching idx coefficient and the first valid-node coefficient.
first_b, second_b = nd_flat[0].b if isinstance(nd_flat[0], MulNode) else 1, same[0].b if isinstance(same[0], MulNode) else 1
k, same_sum = second_b//first_b, Variable.sum(same)
# Only record the node if the matched idx components are exactly k times the valid node's left side.
if k*(node.a) == same_sum: same_dict[same_sum] = same_dict.get(same_sum, []) + [(k, node)]
for key in same_dict.keys():
same, mnn, mxn = key.flat_components, key.min, key.max # type: ignore # Same is sumnode because node.a is SumNode
# Tighten the sub-sum's min (negative k) or max (positive k) from each recorded valid node.
for k, node in same_dict[key]: # TODO: This part may need more thinking
if k < 0: mnn = (-k)*max((-node.b) + 1, min([-lal.b if isinstance(lal, MulNode) else 1 for lal in same]))
else: mxn = (node.b - 1)*k
# Stand in a temporary variable with the tightened bounds for the sub-sum, rebuild the modulo
# (presumably so the `%` can simplify with the narrower range), then put the real sub-sum back.
fake_var = Variable("valid_fake", mnn, mxn)
total = (Variable.sum([x for x in idx_nodes if x not in same]) + fake_var) % idx.b
idx = total.substitute({fake_var: key})
# TODO: If idx has no ModNode we may be able to remove the valid node, but removing it needs careful thinking
# Stage 3: simplify SumNodes
# This part just removes valid nodes if node is exactly same as idx or idy
# idx = 3*a + b (+ 5), valid = 3*a + b < 10 # Valid will be removed as idx will go out of bounds
# Check for var intersection, removing valid can affect other index
if valid.min == 0 and not idx_vars.intersection(idy_vars):
nds = valid.nodes if isinstance(valid, AndNode) else [valid]
flats = [id.flat_components for id in (idx, idy) if isinstance(id, SumNode)]
# Symbolic part of idx / idy: the sum with constant (NumNode) terms dropped.
sym_sums = [Variable.sum([i for i in flat if not isinstance(i, NumNode)]) for flat in flats]
# Valid nodes whose left side equals a symbolic part, or its negation.
ones = [node for sym_sum in sym_sums for node in nds if (node.a == sym_sum) or (-(node.a) == sym_sum)] # type: ignore # AndNode always consists of LtNode
valid = Variable.ands([i for i in nds if i not in ones])
# Stage 4: this is the slow part
# This part is for brute forcing all possible values of idx, idy and valid
# If valid is both 0 and 1 for the same (idx, idy) we can not delete the valid
if valid.min == 0 and not isinstance(idx, ModNode):
variables = tuple(val_vars | idy_vars | idx_vars)
val_infer, idx_infer, idy_infer = valid.expand(variables), idx.expand(variables), idy.expand(variables)
# Coordinates seen with valid == 0 and with valid == 1, respectively.
val_dict: Dict[int, Set[Tuple[int,int]]] = {0:set(), 1:set()}
for v, x, y in zip(val_infer, idx_infer, idy_infer): val_dict[v.min].add((x.min, y.min))
# No coordinate is both valid and invalid, so the check is replaced by the constant 1.
if not val_dict[1].intersection(val_dict[0]): valid = NumNode(1)
if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
return (idx, idy), valid
class UOp(NamedTuple):
uop: UOps
@@ -129,10 +190,11 @@ class Linearizer(OptimizedKernel):
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
if isinstance(self.bufs[i].dtype, ImageDType):
idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
idx, valid = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
else:
rendered_idx = idx.render(self.render_ops, self)
if valid.min == 0:
valid_rendered = valid.render(self.render_ops, self)
self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)))
@@ -168,7 +230,7 @@ class Linearizer(OptimizedKernel):
for idx, var in store_offset.items():
idx, valid = self.sts[i].expr_idxs(idx)
if isinstance(self.bufs[i].dtype, ImageDType):
idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
idx, valid = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
else:
rendered_idx = idx.render(self.render_ops, self)

View File

@@ -97,6 +97,7 @@ class Node:
assert b > 0
if b == 1: return NumNode(0)
if self.min >= 0 and self.max < b: return self
if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
if self.min < 0: return (self - ((self.min//b)*b)) % b
return create_node(ModNode(self, b))