clean up ops file [pr] (#7013)

This commit is contained in:
George Hotz
2024-10-12 19:53:52 +08:00
committed by GitHub
parent 746a1f8c86
commit cba4b9a058

View File

@@ -153,7 +153,6 @@ COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, Bin
# for each scope-opening op, the (op expected last inside, op that closes the scope) pair — TODO confirm exact semantics
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
# With True as the default, this matches the old symbolic behavior
# python3 -c 'from tinygrad.ops import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
def resolve(x, default:bool=True):
  # plain Python values (non-UOp) resolve by ordinary truthiness
  if not isinstance(x, UOp): return bool(x)
  # only boolean-dtype UOps may be resolved to a Python bool
  assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
  # NOTE(review): function body continues beyond this diff hunk
@@ -195,6 +194,18 @@ class UOp(MathTrait):
new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
return UOp(*new_args)
@functools.cached_property
def key(self) -> bytes:
  # content-addressed identity: sha256 over (op, dtype, arg) concatenated with the keys of all srcs (recursive)
  return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
# tree-style repr rendered by pretty_print; "%s" is the placeholder pretty_print fills with the rendered srcs
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
# REDUCE_AXIS args are tuples and render as "(a, b, ...)"; all other ops show arg as-is
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
# all transitive src UOps (dict used as an ordered set) — cached, so it assumes the UOp graph is immutable
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
@property  # parents with self included
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
# *** uop shape stuff ***
# whether this op logically carries a shape; buffer/const/var definitions do not
@property
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.BUFFER, UOps.CONST, UOps.DEFINE_VAR}
@functools.cached_property
@@ -207,11 +218,11 @@ class UOp(MathTrait):
from tinygrad.shape.shapetracker import ShapeTracker
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is UOps.REDUCE_AXIS else src_sts[0]
@functools.cached_property
def key(self) -> bytes:
  # content-addressed identity: sha256 over (op, dtype, arg) concatenated with the keys of all srcs (recursive)
  return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
# tree-style repr rendered by pretty_print; "%s" is the placeholder pretty_print fills with the rendered srcs
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
# REDUCE_AXIS args are tuples and render as "(a, b, ...)"; all other ops show arg as-is
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
# overall shape: a VIEW knows it from its arg; otherwise the elementwise max over the shapes of srcs that have one
def full_shape(self) -> Tuple[sint, ...]:
  return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
# *** uop evaluation ***
def simplify(self):
  """Return this UOp rewritten with the symbolic simplification rules (match-stat tracking disabled)."""
  with Context(TRACK_MATCH_STATS=0):
    return graph_rewrite(self, symbolic)
@@ -226,7 +237,10 @@ class UOp(MathTrait):
# concrete evaluation to Python scalars — presumably _eval checks the dtype is in the allowed set, confirm
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return self._eval(dtypes.floats, float)
# *** uop syntactic sugar
# replace sub-UOps according to dvars (old -> new) via a graph rewrite with the _substitute matcher
def substitute(self, dvars:Dict[UOp, UOp]): return graph_rewrite(self, _substitute, dvars)
# *** uop syntactic sugar ***
@property
def st_arg(self) -> ShapeTracker:
  # only BUFFER_UOPS carry a ShapeTracker arg; NOTE(review): body continues beyond this diff hunk
  assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
@@ -276,8 +290,13 @@ class UOp(MathTrait):
if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
@staticmethod
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
  """Build a RANGE UOp from start to end; non-UOp bounds are wrapped as consts of dtype, idx tags the loop."""
  lo = start if isinstance(start, UOp) else UOp.const(dtype, start)
  hi = end if isinstance(end, UOp) else UOp.const(dtype, end)
  return UOp(UOps.RANGE, dtype=dtype, src=(lo, hi), arg=idx)
# fold self over the given RANGE UOps with binary op `op`, keeping self.dtype
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
# *** Variable stuff ***
# *** uop Variable stuff ***
# create a named symbolic integer (DEFINE_VAR); arg = (name, min_val, max_val) — bounds presumably inclusive, confirm
@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(name, min_val, max_val))
@@ -305,20 +324,8 @@ class UOp(MathTrait):
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
# NOTE(review): removed side of the diff — same substitute, using the pre-rename `substitute` matcher
def substitute(self, dvars:Dict[UOp, UOp]): return graph_rewrite(self, substitute, dvars)
# *** uop symbolic stuff ***
# NOTE(review): removed side of the diff — range/reduce were moved earlier in the class
@staticmethod
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
  return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
                                           UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
# fold self over the given RANGE UOps with binary op `op`, keeping self.dtype
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
# NOTE(review): removed side of the diff — parents/sparents were moved earlier in the class
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
  # NOTE(review): removed side of the diff — VIEW takes its arg's shape, others the elementwise max of src shapes
  return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
def const_factor(self) -> int:
  """largest known int that divides self"""
  if self.op is UOps.CONST: return self.arg  # a constant is divided exactly by itself
  # NOTE(review): body continues beyond this diff hunk
@@ -413,13 +420,6 @@ def exec_alu(op:Op, dtype:DType, operands):
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
def uop_alu_resolve(u:UOp) -> sint:
  """Evaluate a tree of CONST/DEFINE_VAR/ALU UOps down to a concrete value; unresolvable ops raise."""
  if u.op is UOps.CONST: return u.arg
  if u.op is UOps.DEFINE_VAR: return u
  #if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2])
  if u.op is UOps.ALU:
    operands = tuple(uop_alu_resolve(s) for s in u.src)
    return exec_alu(u.arg, u.dtype, operands)
  raise RuntimeError(f"ALU resolve fail @ {u.op}")
# ***** uop helpers *****
def print_uops(uops:List[UOp]):
@@ -446,7 +446,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
for u in uops:
if u.op is UOps.RANGE:
mult_stack.append(mults)
mults *= uop_alu_resolve(u.src[1] - u.src[0])
mults *= (u.src[1] - u.src[0]).ssimplify()
elif u.op is UOps.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.op is UOps.SPECIAL:
@@ -964,7 +964,7 @@ symbolic_flat = symbolic+PatternMatcher([
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])
# NOTE(review): removed side of the diff — pre-rename matcher mapping any UOp through the ctx dict; None means "no replacement"
substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
# matcher that maps any UOp through the ctx dict (old -> new); returning None leaves the UOp unchanged
_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
# for debug
renderer = PatternMatcher([