uops work from lowerer [run_process_replay] (#5279)

This commit is contained in:
George Hotz
2024-07-03 09:40:00 -07:00
committed by GitHub
parent 622b7bd556
commit 16e3b8b013

View File

@@ -55,6 +55,7 @@ class UOp:
def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
def ne(self, x): return UOp.alu(BinaryOps.CMPNE, self, ufix(self.dtype, x))
def eq(self, x): return -self.ne(x)
def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
def ge(self, x): return -self.lt(x)
def max(self, x): return UOp.alu(BinaryOps.MAX, self, x)
@@ -79,6 +80,44 @@ class UOp:
@property # parents with self
def sparents(self) -> Set[UOp]: return set([self]).union(self.parents)
def vars(self) -> Set[UOp]: return set([x for x in set.union(set([self]), self.parents) if x.op is UOps.DEFINE_VAR])
def divides(self, v):
if self.op is UOps.CONST:
return self.arg%v == 0
if self.op is UOps.ALU:
if self.arg is BinaryOps.ADD: return all(x.divides(v) for x in self.src)
if self.arg is BinaryOps.MUL: return any(x.divides(v) for x in self.src)
return False # generic false if we aren't sure
def type_verify(uops):
for u in uops:
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
if uop in (UOps.CONST, UOps.DEFINE_ACC):
if uop is UOps.DEFINE_ACC:
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
arg = src[0].arg
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
if uop is UOps.CAST and dtype is not None and dtype.count > 1: assert len(src) == dtype.count
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
if uop is UOps.ALU:
if arg in UnaryOps:
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg is BinaryOps.IDIV:
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
# the distance to shift isn't typechecked
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in BinaryOps:
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg == TernaryOps.WHERE:
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
def uop_alu_resolve(u:UOp) -> sint:
if u.op is UOps.CONST: return u.arg
@@ -99,11 +138,13 @@ class UPat:
name: Optional[str] = None
dtype: Optional[Union[DType, Set[DType]]] = None
allow_len: Set[int] = field(default_factory=set)
allow_any_len: bool = False
@staticmethod
def compile(u: UOp, name:Optional[str]=None) -> UPat:
if u.op is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.src) == 0 else UPat.compile(u.src[0], name or u.arg)
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None, name, u.dtype)
return UPat(u.op, u.arg, (list if u.commutative() else tuple)([UPat.compile(src) for src in u.src]) if u.src != () else None,
name, u.dtype, allow_any_len=(isinstance(name, str) and 'allow_any_len' in name))
T = TypeVar("T")
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool: return m2 not in m1 if isinstance(m1, set) else m2 != m1
@@ -118,7 +159,7 @@ def _match(uop:UOp, pat:UPat, store:Dict[str, UOp]) -> bool:
# try all permutations if it's a list
# repeat if it's a UPat
for vp in itertools.permutations(pat.src) if isinstance(pat.src,list) else ([pat.src] if isinstance(pat.src,tuple) else [(pat.src,)*len(uop.src)]):
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len): return False
if len(uop.src) != len(vp) and (len(uop.src) not in pat.allow_len) and not pat.allow_any_len: return False
new_store = store.copy()
if all(_match(uu, vv, new_store) for uu, vv in zip(uop.src, vp)):
store.update(new_store)
@@ -141,7 +182,7 @@ class PatternMatcher:
def rewrite(self, uop:UOp) -> Optional[UOp]:
for p,fxn in itertools.chain(self.pdict[(uop.op, uop.arg)], self.pdict[(uop.op, None)]):
store: Dict[str, UOp] = {}
if _match(uop, p, store): return fxn(**store)
if _match(uop, p, store) and (ret:=fxn(**store)) is not None: return ret # NOTE: if it returns None, we keep trying to match
return None
def sum_collapse(phi_input, loop, val1, val2):
@@ -328,7 +369,7 @@ class UOpGraph:
for i,u in enumerate(self):
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str([self.uops.index(x) for x in u.src]):32s} {u.arg}")
def linearize(self, extra_pm:Optional[PatternMatcher]=None, type_verify=True):
def linearize(self, extra_pm:Optional[PatternMatcher]=None, do_type_verify=True):
# NOTE: relinearizering should be okay
#assert self._uops is None, "already linearized"
@@ -398,7 +439,7 @@ class UOpGraph:
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
self._uops = self._uops[:-1]
if type_verify: self.type_verify()
if do_type_verify: type_verify(self.uops)
# *** checker functions ***
@@ -434,33 +475,3 @@ class UOpGraph:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // 32 * mults
return flops, mem
def type_verify(self):
for u in self.uops:
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
if uop in (UOps.CONST, UOps.DEFINE_ACC):
if uop is UOps.DEFINE_ACC:
assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
arg = src[0].arg
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
if uop is UOps.STORE and len(src) == 4: assert src[3].dtype == dtypes.bool
if uop is UOps.ALU:
if arg in UnaryOps:
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in (BinaryOps.CMPLT, BinaryOps.CMPNE):
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg is BinaryOps.IDIV:
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), \
f"input dtype mismatch {dtypes.int} != {src[0].dtype=} != {src[1].dtype=}"
assert dtypes.is_int(dtype), f"{arg} output dtype mismatch {dtype=} != {dtypes.int}"
elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
# the distance to shift isn't typechecked
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in BinaryOps:
assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg == TernaryOps.WHERE:
assert src[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {src[0].dtype=} != {dtypes.bool}"
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"