diff symbolic with uops [run_process_replay] (#5841)

* diff symbolic with uops

* mergable symbolic diff
This commit is contained in:
George Hotz
2024-07-31 15:15:01 -07:00
committed by GitHub
parent 72621d9e7c
commit 5ff3e46718
2 changed files with 72 additions and 52 deletions

View File

@@ -22,45 +22,62 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.bigint, sel
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
if getenv("UOP_IS_SYMBOLIC"):
# TODO: change this once UOps is ready to replace symbolic
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
# TODO: dtypes.realint
iexpr = variable_to_uop(view.offset)
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
if sh != 1 and st != 0: iexpr = iexpr + idx*variable_to_uop(st)
if m is not None:
if m[0] != 0: vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
return iexpr, vexpr
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
# TODO: dtypes.realint
iexpr = variable_to_uop(view.offset)
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
if sh != 1 and st != 0: iexpr = iexpr + idx*variable_to_uop(st)
if m is not None:
if m[0] != 0: vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
if m[1] != sh: vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
return iexpr, vexpr
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
for view in reversed(st.views[0:-1]):
view = view.minify()
acc, idxs = 1, []
for _d in reversed(view.shape):
d = variable_to_uop(_d)
idxs.append((idx//acc)%d)
acc *= d
idx, valid = _uop_view(view, idxs[::-1], valid)
if isinstance(dtype, ImageDType):
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), ((idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4))
return idx, valid
else:
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
idx, valid = st.expr_idxs(fake_idxs)
ctx = dict(zip(fake_idxs, idxs))
uvalid = valid.render(render_ops, ctx)
if isinstance(dtype, ImageDType):
image_idxs = (idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4
uidx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), tuple(x.render(render_ops, ctx) for x in image_idxs))
else:
uidx = idx.render(render_ops, ctx)
if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
assert uvalid.dtype == dtypes.bool
return uidx, uvalid
# TODO: change this once UOps is ready to replace symbolic
def st_to_uops_graph(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
for view in reversed(st.views[0:-1]):
view = view.minify()
acc, idxs = 1, []
for _d in reversed(view.shape):
d = variable_to_uop(_d)
idxs.append((idx//acc)%d)
acc *= d
idx, valid = _uop_view(view, idxs[::-1], valid)
if isinstance(dtype, ImageDType):
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), ((idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4))
return idx, valid
# TODO: this is the old one, delete when ready
def st_to_uops_symbolic(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
fake_idxs = [Variable(f"__idx{i}", 0, s-1) for i,s in enumerate(st.shape)]
idx, valid = st.expr_idxs(fake_idxs)
ctx = dict(zip(fake_idxs, idxs))
uvalid = valid.render(render_ops, ctx)
if isinstance(dtype, ImageDType):
image_idxs = (idx // 4) % dtype.shape[1], (idx // (4 * dtype.shape[1])), idx % 4
uidx = UOp(UOps.VECTORIZE, dtypes.int.vec(3), tuple(x.render(render_ops, ctx) for x in image_idxs))
else:
uidx = idx.render(render_ops, ctx)
if uvalid.op is UOps.CONST: uvalid = UOp.const(dtypes.bool, uvalid.arg)
assert uvalid.dtype == dtypes.bool
return uidx, uvalid
def st_to_uops(st:ShapeTracker, idxs:List[UOp], dtype:DType) -> Tuple[UOp, UOp]:
if getenv("SYMBOLIC_DIFF"):
symbolic_idx, symbolic_valid = st_to_uops_symbolic(st, idxs, dtype)
graph_idx, graph_valid = st_to_uops_graph(st, idxs, dtype)
import ocdiff
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.renderer.cstyle import OpenCLRenderer
def render(s1, s2):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=("idxs", True))
st = tuple(UOp(UOps.STORE, None, (glbl, UOp.const(dtypes.int, i), s)) for i,s in enumerate([s1,s2]))
return OpenCLRenderer().render("indexing", UOpGraph(UOp(UOps.SINK, None, st)).linearize(skip_check=True))
cmp_symbolic, cmp_graph = render(symbolic_idx, symbolic_valid), render(graph_idx, graph_valid)
if cmp_symbolic != cmp_graph: print(ocdiff.console_diff(f"SYMBOLIC {len(cmp_symbolic)}\n"+cmp_symbolic, f"GRAPH {len(cmp_graph)}\n"+cmp_graph))
return st_to_uops_graph(st, idxs, dtype) if getenv("UOP_IS_SYMBOLIC") else st_to_uops_symbolic(st, idxs, dtype)
def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
# TODO: symbolic shape

View File

@@ -485,7 +485,7 @@ class UOpGraph:
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype) if u.dtype is not None else '':25s} " f"{str(formatted_parents):32s} {u.arg}")
cnt = 0
def linearize(self, extra_pm:Optional[PatternMatcher]=None):
def linearize(self, extra_pm:Optional[PatternMatcher]=None, skip_check=False) -> UOpGraph:
global acc_number
acc_number = 0
@@ -551,19 +551,20 @@ class UOpGraph:
for u, x in scope_end.items(): self._uops.insert(self._uops.index(x)+1, UOp(END_FOR_UOP[u.op][1], None, (u,)))
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
bad_ops = dedup([x.op for x in self._uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}])
try:
type_verify(self.uops)
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}"
# TODO: this should be enabled, and the valid clause should be removed
# NOTE: multiple identical stores to DEFINE_LOCAL is okay
assert len(all_stores := [x.src[0:2]+x.src[3:] for x in self._uops if x.op is UOps.STORE and x.src[0].op is not UOps.DEFINE_LOCAL]) \
== len(dedup(all_stores)), "repeated stores in uops"
except AssertionError as e:
self.print()
if not CI: self.graph()
raise e
if not skip_check:
bad_ops = dedup([x.op for x in self._uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE}])
try:
type_verify(self.uops)
assert self._uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {self._uops[-1]}"
assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}"
# TODO: this should be enabled, and the valid clause should be removed
# NOTE: multiple identical stores to DEFINE_LOCAL is okay
assert len(all_stores := [x.src[0:2]+x.src[3:] for x in self._uops if x.op is UOps.STORE and x.src[0].op is not UOps.DEFINE_LOCAL]) \
== len(dedup(all_stores)), "repeated stores in uops"
except AssertionError as e:
self.print()
if not CI: self.graph()
raise e
# strip the SINK
self._uops = self._uops[:-1]
@@ -571,3 +572,5 @@ class UOpGraph:
if getenv("FUZZ_UOPS"):
from test.external.fuzz_uops import fuzz_uops
self._fuzz_paths = fuzz_uops(self)
return self