This reverts commit 8186e4e7d6.
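Orientation, inferred from the hunks below: the commit being reverted had encoded a shaped constant in the scheduler AST as a boolean UOps.VALID gate consumed by a WHERE select; this revert restores the earlier encoding in which UOps.CONST itself carries a ShapeTracker source and the value. A toy sketch of the two graph shapes (a hypothetical namedtuple stand-in, not tinygrad's actual UOp class):

    from collections import namedtuple

    # toy stand-in for a UOp: an op name, a tuple of sources, an argument
    U = namedtuple("U", "op src arg", defaults=[(), None])

    st = U("SHAPETRACKER", arg="ShapeTracker.from_shape((4, 4))")

    # reverted encoding (the "-" lines below): a shaped const is a WHERE gated by VALID
    gated = U("ALU", (U("VALID", (st,)), U("CONST", arg=1), U("CONST", arg=0)), "WHERE")

    # restored encoding (the "+" lines below): CONST carries the st src and the value
    plain = U("CONST", (st,), 1)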
@@ -7,7 +7,7 @@ from test.external.process_replay.process_replay import _run_differ
 PAGE_SIZE = 100
 RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
 TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
-LOGOPS = os.getenv("LOGOPS", "/tmp/ops")
+LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
 
 def extract_ast(offset:int):
   logops = open(LOGOPS, "a")
@@ -64,8 +64,8 @@ def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
 
 def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
   if st_src is None:
-    st_src = st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),
-  return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
+    st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
+  return UOp(UOps.CONST, dtype, st_src, dtypes.as_const(val, dtype))
 
 T = TypeVar("T")
 def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
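Note on the ast_const hunk above: besides swapping VALID+WHERE for a ShapeTracker-carrying CONST, the "+" lines make the 1-tuple explicit (the removed line relied on the bare trailing comma) and coerce the value with dtypes.as_const. A plain-Python illustration of that comma binding, for reference:

    # the comma binds after the whole conditional expression, so both are 1-tuples
    x = 1 if True else 2,
    y = (1 if True else 2,)
    assert x == y == (1,)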
@@ -572,13 +572,13 @@ class TestMultiTensor(unittest.TestCase):
       assert ast.op is UOps.STORE
       assert ast.src[2].arg is BinaryOps.ADD
       assert ast.src[2].src[0].op is UOps.LOAD and ast.src[2].src[0]
-      assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1
+      assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 1
     t = 2 * t
     for si in t.schedule():
       ast = si.ast.src[0]
       assert ast.op is UOps.STORE
       assert ast.src[2].arg is BinaryOps.MUL
-      assert ast.src[2].src[0].src[1].op is UOps.CONST and ast.src[2].src[0].src[1].arg == 2
+      assert ast.src[2].src[0].op is UOps.CONST and ast.src[2].src[0].arg == 2
       assert ast.src[2].src[1].op is UOps.LOAD
     t = t + t.full_like(3)
     for si in t.schedule():
@@ -586,7 +586,7 @@ class TestMultiTensor(unittest.TestCase):
       assert ast.op is UOps.STORE
       assert ast.src[2].arg is BinaryOps.ADD
       assert ast.src[2].src[0].op is UOps.LOAD
-      assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 3
+      assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 3
 
   def test_shard_memory(self):
     devices = (d0, d1, d2, d3)
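The test edits in the two hunks above are mechanical consequences of the encoding change: under the gated encoding the constant operand of an ADD/MUL is a WHERE whose second source holds the CONST, so the asserts reached one level deeper (src[1].src[1]); after the revert the CONST is the direct operand. A toy check of the two depths (same hypothetical namedtuple stand-in as in the sketch at the top):

    from collections import namedtuple
    U = namedtuple("U", "op src arg", defaults=[(), None])

    gated = U("ALU", (U("LOAD"), U("ALU", (U("VALID"), U("CONST", arg=1), U("CONST", arg=0)), "WHERE")), "ADD")
    assert gated.src[1].src[1].arg == 1  # const nested one level under the WHERE
    plain = U("ALU", (U("LOAD"), U("CONST", arg=1)), "ADD")
    assert plain.src[1].arg == 1         # const is the direct operand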
@@ -79,7 +79,7 @@ class TestVerifyAST(unittest.TestCase):
     uop_sts = verify_ast(a.schedule()[-1].ast)
     store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
     self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
-    const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0]
+    const_st = [st for u,st in uop_sts.items() if u.op is UOps.CONST][0]
     self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
 
 if __name__ == '__main__':
@@ -639,7 +639,7 @@ class Kernel:
       # for locals, we use the ShapeTracker that's in the srcs
       st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
       st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
-      if op.op is UOps.VALID: return replace(op, src=(st_uop,))
+      if op.op is UOps.CONST: return replace(op, src=(st_uop,))
       if op.op is UOps.STORE: return replace(op, src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
       return replace(op, src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
     if op.op is UOps.REDUCE_AXIS:
@@ -777,8 +777,8 @@ class Kernel:
 def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
   if uop in sts: return
   op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
-  # NOTE: UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL, UOps.CONST and UOps.DEFINE_VAR don't have shape
-  if op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.CONST, UOps.DEFINE_VAR}: return
+  # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
+  if op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
   # restore globals from the two stage reduce
   if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL:
     _assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts)
@@ -791,8 +791,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
   assert op in {UOps.SHAPETRACKER, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
   # movementops are pushed to the edges with SHAPETRACKER
   # elementwise inherits shape
-  st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]]
-  for x in (src[0:1] if len(src) and src[0].op is UOps.VALID else src[1:] if op in BUFFER_UOPS else src):
+  st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else -1]]
+  for x in (src[1:] if op in BUFFER_UOPS else src):
     if sts[x].shape != st.shape:
       if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")
       raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
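The restored indexing above (uop.st_loc for BUFFER_UOPS, else -1, and iterating src[1:]) matches the source layouts implied by the st_loc property further down: LOAD and STORE keep their ShapeTracker uop at src[1] after the buffer, while a CONST keeps its only source, the ShapeTracker, at src[0]. A toy check of those layouts (hypothetical namedtuple stand-in):

    from collections import namedtuple
    U = namedtuple("U", "op src arg", defaults=[(), None])

    def st_loc(op): return 0 if op == "CONST" else 1  # mirrors UOp.st_loc below

    st = U("SHAPETRACKER", arg="st")
    load = U("LOAD", (U("DEFINE_GLOBAL"), st))
    const = U("CONST", (st,), 1)
    assert load.src[st_loc(load.op)] is st and const.src[st_loc(const.op)] is st
    # src[1:] skips the buffer for LOAD/STORE and is empty for a CONST
    assert load.src[1:] == (st,) and const.src[1:] == ()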
@@ -91,9 +91,9 @@ class IndependentLowerer:
   def _to_uop(self, x:UOp) -> UOp:
     if x.op in BUFFER_UOPS:
       idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
-      if x.op is UOps.VALID: return valid
       # TODO: check has_valid in UPat, not here
       has_valid = valid.op is not UOps.CONST or valid.arg is not True
+      if x.op is UOps.CONST: return valid.where(x.const_like(x.arg), x.const_like(0))
       buf = x.src[0]
       if x.op is UOps.LOAD:
         barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
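The lowerer hunk above shows where the gate went: the scheduler AST keeps a plain CONST, and the mask reappears only at lowering, where to_indexed_uops produces a valid expression and a masked constant lowers to valid.where(value, 0). A scalar stand-in for that select, just to pin the semantics:

    # scalar stand-in for valid.where(x.const_like(x.arg), x.const_like(0))
    def lower_masked_const(valid: bool, val: int) -> int:
      return val if valid else 0

    assert lower_masked_const(True, 3) == 3 and lower_masked_const(False, 3) == 0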
@@ -66,7 +66,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
       val, var_val = val.unbind()
       var_vals[val] = var_val
     else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
-    return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0))
+    return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val)
   # otherwise, it's a load and we add it to the inputs
   if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
       ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
@@ -136,9 +136,8 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
 def get_output_st(uop:UOp, uop_sts:Dict[UOp, ShapeTracker]) -> Optional[ShapeTracker]:
   if (st:=uop_sts.get(uop)): return st
   if uop.op in BUFFER_UOPS: return uop.st_arg
-  src = [x for x in uop.src if x.op not in {UOps.CONST, UOps.DEFINE_VAR}]
-  src_sts = [xst for x in src if (xst:=get_output_st(x, uop_sts)) is not None]
-  if len(src_sts) != len(src) or not all_same([x.shape for x in src_sts]): return None
+  src_sts = [xst for x in uop.src if (xst:=get_output_st(x, uop_sts)) is not None]
+  if len(src_sts) != len(uop.src) or not all_same([x.shape for x in src_sts]): return None
   uop_sts[uop] = st = ShapeTracker.from_shape(src_sts[0].reduce(uop.arg[1])) if uop.op is UOps.REDUCE_AXIS else src_sts[0]
   return st
 
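get_output_st above can drop its CONST/DEFINE_VAR filter because, post-revert, CONST is a BUFFER_UOPS member again and answers the `uop.op in BUFFER_UOPS` branch with its own st_arg, so every source reports a ShapeTracker. A toy version of the restored lookup (hypothetical stand-ins, string ops instead of the real enum):

    from collections import namedtuple
    U = namedtuple("U", "op src arg", defaults=[(), None])

    BUFFER_UOPS = {"LOAD", "STORE", "CONST"}  # post-revert membership

    def get_output_st(u):
      # toy version of the restored logic: BUFFER_UOPS answer with their own st,
      # other ops inherit from sources, with no special-casing of CONST anymore
      if u.op in BUFFER_UOPS: return u.src[0 if u.op == "CONST" else 1].arg
      src_sts = [get_output_st(x) for x in u.src]
      return src_sts[0] if src_sts and all(s == src_sts[0] for s in src_sts) else None

    st = U("SHAPETRACKER", arg="(4, 4)")
    add = U("ALU", (U("LOAD", (U("DEFINE_GLOBAL"), st)), U("CONST", (st,), 1)), "ADD")
    assert get_output_st(add) == "(4, 4)"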
@@ -319,9 +319,8 @@ class UOps(Enum):
   # ops that are not graph nodes
   ENDRANGE = auto()
   ENDIF = auto()
-  VALID = auto()
 
-BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.VALID}
+BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
 
 END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
 
@@ -349,7 +348,7 @@ class UOp(MathTrait):
     self.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR})
   # *** uop syntactic sugar
   @property
-  def st_loc(self) -> int: return 0 if self.op is UOps.VALID else 1
+  def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
   @property
   def st_arg(self) -> ShapeTracker:
     assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
@@ -385,12 +384,11 @@ class UOp(MathTrait):
   def full_shape(self) -> Tuple[sint, ...]:
     if self.op is UOps.SHAPETRACKER: return self.arg.shape
     # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
-    return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL, \
-      UOps.CONST, UOps.DEFINE_VAR}]))
+    return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
   def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
   def variables(self) -> List[Variable]:
     st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
-    return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1].arg, x.arg[2].arg) for x in self.vars()]), key=lambda v: v.expr)
+    return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.expr)
   def const_factor(self) -> int:
     """largest known int that divides self"""
     if self.op is UOps.CONST: return self.arg
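One more small revert rides along in the last hunk: variables(). Under the reverted commit, DEFINE_VAR packed its bounds as CONST uops in arg (hence the extra .arg unwrapping in the removed line); the restored line reads them as plain ints. A toy comparison of the two arg layouts (hypothetical values):

    from collections import namedtuple
    U = namedtuple("U", "op src arg", defaults=[(), None])

    gated_arg = ("i", U("CONST", arg=0), U("CONST", arg=10))  # bounds as CONST uops
    plain_arg = ("i", 0, 10)                                  # bounds as plain ints
    assert (gated_arg[1].arg, gated_arg[2].arg) == (plain_arg[1], plain_arg[2])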