mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
clean up ops file [pr] (#7013)
This commit is contained in:
@@ -153,7 +153,6 @@ COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, Bin
|
||||
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
|
||||
|
||||
# With True as the default, this matches the old symbolic behavior
|
||||
# python3 -c 'from tinygrad.ops import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
|
||||
def resolve(x, default:bool=True):
|
||||
if not isinstance(x, UOp): return bool(x)
|
||||
assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
|
||||
@@ -195,6 +194,18 @@ class UOp(MathTrait):
|
||||
new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
|
||||
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
|
||||
return UOp(*new_args)
|
||||
@functools.cached_property
|
||||
def key(self) -> bytes:
|
||||
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
|
||||
@property # parents with self
|
||||
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
||||
|
||||
# *** uop shape stuff ***
|
||||
|
||||
@property
|
||||
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.BUFFER, UOps.CONST, UOps.DEFINE_VAR}
|
||||
@functools.cached_property
|
||||
@@ -207,11 +218,11 @@ class UOp(MathTrait):
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is UOps.REDUCE_AXIS else src_sts[0]
|
||||
@functools.cached_property
|
||||
def key(self) -> bytes:
|
||||
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
|
||||
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
|
||||
# *** uop evaluation ***
|
||||
|
||||
def simplify(self):
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
return graph_rewrite(self, symbolic)
|
||||
@@ -226,7 +237,10 @@ class UOp(MathTrait):
|
||||
def __bool__(self): return self._eval((dtypes.bool,), bool)
|
||||
def __int__(self): return self._eval(dtypes.ints, int)
|
||||
def __float__(self): return self._eval(dtypes.floats, float)
|
||||
# *** uop syntactic sugar
|
||||
def substitute(self, dvars:Dict[UOp, UOp]): return graph_rewrite(self, _substitute, dvars)
|
||||
|
||||
# *** uop syntactic sugar ***
|
||||
|
||||
@property
|
||||
def st_arg(self) -> ShapeTracker:
|
||||
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
|
||||
@@ -276,8 +290,13 @@ class UOp(MathTrait):
|
||||
if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
|
||||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
|
||||
# *** Variable stuff ***
|
||||
# *** uop Variable stuff ***
|
||||
|
||||
@staticmethod
|
||||
def variable(name:str, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(name, min_val, max_val))
|
||||
@@ -305,20 +324,8 @@ class UOp(MathTrait):
|
||||
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
|
||||
|
||||
def substitute(self, dvars:Dict[UOp, UOp]): return graph_rewrite(self, substitute, dvars)
|
||||
# *** uop symbolic stuff ***
|
||||
|
||||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
|
||||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
|
||||
@property # parents with self
|
||||
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is UOps.CONST: return self.arg
|
||||
@@ -413,13 +420,6 @@ def exec_alu(op:Op, dtype:DType, operands):
|
||||
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
|
||||
return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))
|
||||
|
||||
def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.op is UOps.CONST: return u.arg
|
||||
if u.op is UOps.DEFINE_VAR: return u
|
||||
#if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2])
|
||||
if u.op is UOps.ALU: return exec_alu(u.arg, u.dtype, tuple(map(uop_alu_resolve, u.src)))
|
||||
raise RuntimeError(f"ALU resolve fail @ {u.op}")
|
||||
|
||||
# ***** uop helpers *****
|
||||
|
||||
def print_uops(uops:List[UOp]):
|
||||
@@ -446,7 +446,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
for u in uops:
|
||||
if u.op is UOps.RANGE:
|
||||
mult_stack.append(mults)
|
||||
mults *= uop_alu_resolve(u.src[1] - u.src[0])
|
||||
mults *= (u.src[1] - u.src[0]).ssimplify()
|
||||
elif u.op is UOps.ENDRANGE:
|
||||
mults = mult_stack.pop(-1)
|
||||
elif u.op is UOps.SPECIAL:
|
||||
@@ -964,7 +964,7 @@ symbolic_flat = symbolic+PatternMatcher([
|
||||
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
|
||||
])
|
||||
|
||||
substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
|
||||
# for debug
|
||||
renderer = PatternMatcher([
|
||||
|
||||
Reference in New Issue
Block a user