mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add UOps.VCONST [run_process_replay] (#6487)
* add UOps.VCONST [run_process_replay] * VCONST folding * simpler devectorize * alu * revert that type
This commit is contained in:
@@ -67,6 +67,28 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
|
||||
#from tinygrad.engine.graph import graph_uops
|
||||
#graph_uops(linearize_uop(new_sink))
|
||||
|
||||
class TestGraphRewriteConst(unittest.TestCase):
|
||||
def test_gep_const(self):
|
||||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
v2 = v1.gep(1)
|
||||
ret = graph_rewrite(v2, constant_folder)
|
||||
self.assertEqual(ret.dtype, dtypes.int)
|
||||
self.assertEqual(ret.arg, 1)
|
||||
|
||||
def test_gep_const_single(self):
|
||||
v1 = UOp.const(dtypes.int.vec(3), 4)
|
||||
v2 = v1.gep(1)
|
||||
ret = graph_rewrite(v2, constant_folder)
|
||||
self.assertEqual(ret.dtype, dtypes.int)
|
||||
self.assertEqual(ret.arg, 4)
|
||||
|
||||
def test_add_const(self):
|
||||
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
|
||||
v2 = UOp.const(dtypes.int.vec(3), (5,6,7))
|
||||
ret = graph_rewrite(v1+v2, constant_folder)
|
||||
self.assertEqual(ret.dtype, dtypes.int.vec(3))
|
||||
self.assertEqual(ret.arg, (5,7,9))
|
||||
|
||||
class TestGraphRewrite(unittest.TestCase):
|
||||
def test_dedup(self):
|
||||
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
|
||||
|
||||
@@ -274,12 +274,14 @@ constant_folder = PatternMatcher([
|
||||
(UPat.max(UPat.var('x'), UPat.var('y')), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
|
||||
# GEP/CAST const rules
|
||||
(UPat(UOps.GEP, src=(UPat.cvar("c"),), name="root"), lambda root, c: root.const_like(c.arg)),
|
||||
(UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="root"), lambda root, c: root.const_like(c.arg[root.arg])),
|
||||
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const_like(c.arg)),
|
||||
# a conditional with the same results either way is a noop, also fold const conditionals
|
||||
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
|
||||
(UPat.cvar('gate').where(UPat.var('c0'), UPat.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
|
||||
# ** constant folding **
|
||||
(UPat(UOps.ALU, name="root", src=UPat(UOps.CONST)), lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
|
||||
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
|
||||
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src]))),
|
||||
# ** self folding **
|
||||
# cast NOOP (NOTE: it's str to deal with PtrDType)
|
||||
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
|
||||
@@ -477,6 +479,9 @@ def delete_redundant_gates(root:UOp) -> Optional[UOp]:
|
||||
|
||||
reducer = PatternMatcher([
|
||||
(UPat(UOps.REDUCE, name="root"), do_reduce),
|
||||
(UPat(UOps.CONST, name='c'),
|
||||
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.count) if c.dtype.count > 1 else None),
|
||||
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
|
||||
# no ALU on vectorized dtypes
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name="alu"), no_vectorized_alu),
|
||||
# delete_redundant_gates (after expand, is this still needed?)
|
||||
|
||||
@@ -55,7 +55,11 @@ class dtypes:
|
||||
if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
|
||||
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
|
||||
@staticmethod
|
||||
def as_const(val: ConstType, dtype:DType): return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
|
||||
def as_const(val: Tuple[ConstType, ...]|ConstType, dtype:DType):
|
||||
if isinstance(val, tuple):
|
||||
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
|
||||
return tuple(dtypes.as_const(x, dtype) for x in val)
|
||||
return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
|
||||
@staticmethod
|
||||
def min(dtype:DType):
|
||||
if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
|
||||
|
||||
@@ -393,14 +393,12 @@ class UOp(MathTrait):
|
||||
return type(self)(UOps.ALU, out_dtype, (self,)+src, arg)
|
||||
@classmethod
|
||||
@functools.lru_cache(None)
|
||||
def const(cls, dtype:DType, b:ConstType|Variable): return cls._const(dtype, b)
|
||||
def const(cls, dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): return cls._const(dtype, b)
|
||||
@classmethod
|
||||
def _const(cls, dtype:DType, b:ConstType|Variable):
|
||||
def _const(cls, dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
|
||||
# TODO: fix dtype of b.max after Variable is just an UOp
|
||||
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
|
||||
if dtype is not None and dtype != (sdtype := dtype.scalar()):
|
||||
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
|
||||
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
||||
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max)))) # type: ignore
|
||||
return cls(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
|
||||
@property # parents with self
|
||||
@@ -417,6 +415,7 @@ class UOp(MathTrait):
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is UOps.CONST: return self.arg
|
||||
if self.op is UOps.VCONST: return math.gcd(*self.arg)
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
||||
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
|
||||
@@ -444,6 +443,7 @@ class UOp(MathTrait):
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is UOps.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
|
||||
if self.op is UOps.CONST: return self.arg, self.arg
|
||||
if self.op is UOps.VCONST: return (min(self.arg), max(self.arg))
|
||||
if self.op is UOps.ALU and self.dtype.count == 1:
|
||||
s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
|
||||
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
|
||||
@@ -500,7 +500,10 @@ truncate: Dict[DType, Callable] = {dtypes.bool: bool,
|
||||
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
|
||||
if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
|
||||
|
||||
def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
|
||||
def exec_alu(op:Op, dtype:DType, operands):
|
||||
if dtype.count > 1:
|
||||
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
|
||||
return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
|
||||
|
||||
def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.op is UOps.CONST: return u.arg
|
||||
|
||||
Reference in New Issue
Block a user