add UOps.VCONST [run_process_replay] (#6487)

* add UOps.VCONST [run_process_replay]

* VCONST folding

* simpler devectorize

* alu

* revert that type
This commit is contained in:
George Hotz
2024-09-12 14:03:39 +08:00
committed by GitHub
parent 4dc9436d63
commit 119b0ea4af
4 changed files with 43 additions and 9 deletions

View File

@@ -67,6 +67,28 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
#from tinygrad.engine.graph import graph_uops
#graph_uops(linearize_uop(new_sink))
class TestGraphRewriteConst(unittest.TestCase):
  """Checks that constant_folder folds GEP and ALU ops applied to vector constants."""

  def test_gep_const(self):
    # lane 1 of a per-lane constant (0,1,2) should fold to the scalar 1
    vec = UOp.const(dtypes.int.vec(3), (0,1,2))
    lane = vec.gep(1)
    folded = graph_rewrite(lane, constant_folder)
    self.assertEqual(folded.dtype, dtypes.int)
    self.assertEqual(folded.arg, 1)

  def test_gep_const_single(self):
    # a vector constant built from one scalar value folds to that same value
    vec = UOp.const(dtypes.int.vec(3), 4)
    lane = vec.gep(1)
    folded = graph_rewrite(lane, constant_folder)
    self.assertEqual(folded.dtype, dtypes.int)
    self.assertEqual(folded.arg, 4)

  def test_add_const(self):
    # adding two per-lane constants should fold lane by lane
    lhs = UOp.const(dtypes.int.vec(3), (0,1,2))
    rhs = UOp.const(dtypes.int.vec(3), (5,6,7))
    folded = graph_rewrite(lhs+rhs, constant_folder)
    self.assertEqual(folded.dtype, dtypes.int.vec(3))
    self.assertEqual(folded.arg, (5,7,9))
class TestGraphRewrite(unittest.TestCase):
def test_dedup(self):
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)

View File

@@ -274,12 +274,14 @@ constant_folder = PatternMatcher([
(UPat.max(UPat.var('x'), UPat.var('y')), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# GEP/CAST const rules
(UPat(UOps.GEP, src=(UPat.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)),
(UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="root"), lambda root, c: root.const_like(c.arg[root.arg])),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const_like(c.arg)),
# a conditional with the same results either way is a noop, also fold const conditionals
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar('gate').where(UPat.var('c0'), UPat.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
# ** self folding **
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
@@ -477,6 +479,9 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
reducer = PatternMatcher([
(UPat(UOps.REDUCE, name="root"), do_reduce),
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
# no ALU on vectorized dtypes
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name="alu"), no_vectorized_alu),
# delete_redundant_gates (after expand, is this still needed?)

View File

@@ -55,7 +55,11 @@ class dtypes:
if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
@staticmethod
def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
def as_const(val: Tuple[ConstType, ...]|ConstType, dtype:DType):
  """
  Coerce `val` to the python type that matches `dtype`.

  A scalar becomes an int for int dtypes, a float for float dtypes and a bool otherwise.
  A tuple is coerced element by element (using the same `dtype`) and must hold exactly
  `dtype.count` entries.
  """
  if not isinstance(val, tuple):
    if dtypes.is_int(dtype): return int(val)
    if dtypes.is_float(dtype): return float(val)
    return bool(val)
  # one entry per lane of the dtype
  assert len(val) == dtype.count, f"mismatch {val} {dtype}"
  return tuple(dtypes.as_const(x, dtype) for x in val)
@staticmethod
def min(dtype:DType):
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)

View File

@@ -393,14 +393,12 @@ class UOp(MathTrait):
return type(self)(UOps.ALU, out_dtype, (self,)+src, arg)
@classmethod
@functools.lru_cache(None)
def const(cls, dtype:DType, b:ConstType|Variable): return cls._const(dtype, b)
def const(cls, dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): return cls._const(dtype, b)
@classmethod
def _const(cls, dtype:DType, b:ConstType|Variable):
def _const(cls, dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
# TODO: fix dtype of b.max after Variable is just an UOp
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
if dtype is not None and dtype != (sdtype := dtype.scalar()):
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max)))) # type: ignore
return cls(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
@property # parents with self
@@ -417,6 +415,7 @@ class UOp(MathTrait):
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
if self.op is UOps.VCONST: return math.gcd(*self.arg)
if self.op is UOps.ALU:
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
@@ -444,6 +443,7 @@ class UOp(MathTrait):
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
if self.op is UOps.CONST: return self.arg, self.arg
if self.op is UOps.VCONST: return (min(self.arg), max(self.arg))
if self.op is UOps.ALU and self.dtype.count == 1:
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
@@ -500,7 +500,10 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
def exec_alu(op:Op, dtype:DType, operands):
  """
  Evaluate ALU `op` on python constant `operands` and return the result for `dtype`.

  When `dtype.count > 1` the op is evaluated once per lane on the scalar dtype: tuple
  operands are indexed by lane, any other operand is reused for every lane, and the
  per-lane results are returned as a tuple. Otherwise `python_alu[op]` is applied and the
  result is passed through the `truncate` entry for `dtype` (unchanged if there is none).
  """
  if not dtype.count > 1:
    return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
  lanes = []
  for lane in range(dtype.count):
    lanes.append(exec_alu(op, dtype.scalar(), [operand[lane] if isinstance(operand, tuple) else operand for operand in operands]))
  return tuple(lanes)
def uop_alu_resolve(u:UOp) -> sint:
if u.op is UOps.CONST: return u.arg