mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
uops work from lowerer [run_process_replay] (#5279)
This commit is contained in:
@@ -55,6 +55,7 @@ class UOp:
|
||||
def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
|
||||
def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
|
||||
def ne(self, x): return UOp.alu(BinaryOps.CMPNE, self, ufix(self.dtype, x))
|
||||
def eq(self, x): return -self.ne(x)
|
||||
def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
|
||||
def ge(self, x): return -self.lt(x)
|
||||
def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
|
||||
@@ -79,6 +80,44 @@ class UOp:
|
||||
@property # parents with self
|
||||
def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
|
||||
def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
|
||||
def divides(self, v):
|
||||
if self.op is UOps.CONST:
|
||||
return self.arg%v == 0
|
||||
if self.op is UOps.ALU:
|
||||
if self.arg is BinaryOps.ADD: return all(x.divides(v) for x in self.src)
|
||||
if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src)
|
||||
return False # generic false if we aren't sure
|
||||
|
||||
def type_verify(uops):
|
||||
for u in uops:
|
||||
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
||||
if uop in (UOps.CONST, UOps.DEFINE_ACC):
|
||||
if uop is UOps.DEFINE_ACC:
|
||||
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
|
||||
arg = src[0].arg
|
||||
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
|
||||
if uop is UOps.CAST and dtype is not None and dtype.count > 1: assert len(src) == dtype.count
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
|
||||
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps:
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
|
||||
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
|
||||
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg is BinaryOps.IDIV:
|
||||
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
|
||||
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
|
||||
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
|
||||
# the distance to shift isn't typechecked
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in BinaryOps:
|
||||
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg == TernaryOps.WHERE:
|
||||
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
|
||||
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||
|
||||
def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.op is UOps.CONST: return u.arg
|
||||
@@ -99,11 +138,13 @@ class UPat:
|
||||
name: Optional[str] = None
|
||||
dtype: Optional[Union[DType, Set[DType]]] = None
|
||||
allow_len: Set[int] = field(default_factory=set)
|
||||
allow_any_len: bool = False
|
||||
|
||||
@staticmethod
|
||||
def compile(u: UOp, name:Optional[str]=None) -> UPat:
|
||||
if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
|
||||
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
|
||||
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None,
|
||||
name, u.dtype, allow_any_len=(isinstance(name, str) and 'allow_any_len' in name))
|
||||
|
||||
T = TypeVar("T")
|
||||
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
|
||||
@@ -118,7 +159,7 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
|
||||
# try all permutations if it's a list
|
||||
# repeat if it's a UPat
|
||||
for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
|
||||
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
|
||||
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len) and not pat.allow_any_len: return False
|
||||
new_store = store.copy()
|
||||
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
|
||||
store.update(new_store)
|
||||
@@ -141,7 +182,7 @@ class PatternMatcher:
|
||||
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
||||
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
|
||||
store: Dict[str, UOp] = {}
|
||||
if _match(uop, p, store): return fxn(**store)
|
||||
if _match(uop, p, store) and (ret:=fxn(**store)) is not None: return ret # NOTE: if it returns None, we keep trying to match
|
||||
return None
|
||||
|
||||
def sum_collapse(phi_input, loop, val1, val2):
|
||||
@@ -328,7 +369,7 @@ class UOpGraph:
|
||||
for i,u in enumerate(self):
|
||||
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
|
||||
|
||||
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
|
||||
def linearize(self, extra_pm:Optional[PatternMatcher]=None, do_type_verify=True):
|
||||
# NOTE: relinearizering should be okay
|
||||
#assert self._uops is None, "already linearized"
|
||||
|
||||
@@ -398,7 +439,7 @@ class UOpGraph:
|
||||
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
|
||||
self._uops = self._uops[:-1]
|
||||
|
||||
if type_verify: self.type_verify()
|
||||
if do_type_verify: type_verify(self.uops)
|
||||
|
||||
# *** checker functions ***
|
||||
|
||||
@@ -434,33 +475,3 @@ class UOpGraph:
|
||||
assert u.arg[1] is not None
|
||||
flops += 2 * prod(u.arg[1]) // 32 * mults
|
||||
return flops, mem
|
||||
|
||||
def type_verify(self):
|
||||
for u in self.uops:
|
||||
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
||||
if uop in (UOps.CONST, UOps.DEFINE_ACC):
|
||||
if uop is UOps.DEFINE_ACC:
|
||||
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
|
||||
arg = src[0].arg
|
||||
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
|
||||
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps:
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
|
||||
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
|
||||
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg is BinaryOps.IDIV:
|
||||
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
|
||||
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
|
||||
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
|
||||
# the distance to shift isn't typechecked
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
elif arg in BinaryOps:
|
||||
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||
elif arg == TernaryOps.WHERE:
|
||||
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
|
||||
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||
|
||||
Reference in New Issue
Block a user