early assert cyclic read [pr] (#7259)

* early assert cyclic read [pr]

* misc
This commit is contained in:
qazal
2024-10-24 11:51:12 +03:00
committed by GitHub
parent b56fab54ea
commit 93934c2160

View File

@@ -1,4 +1,4 @@
import sys, atexit, functools
import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
@@ -191,13 +191,18 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu
if buf.metadata is not None: metadata[ret] = buf.metadata
return ret
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_vals:Dict[Variable, int]) -> LBScheduleItem:
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], uop_bufs:Dict[UOp, Buffer],
var_vals:Dict[Variable, int]) -> LBScheduleItem:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
cache: Dict[LazyBuffer, UOp] = {}
inputs: List[LazyBuffer] = []
metadata: Dict[UOp, Metadata] = {}
sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
to_uop(out, outs, inputs, buf_uops, metadata, cache)) for out in outs))
# assert cyclic dependency
for b,reads in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD, UOps.LOAD}), key=lambda x:x.src[0]):
if not all_same([x.op for x in reads]):
raise RuntimeError(f"cycle detected in kernel.\nhelp: consider using .contiguous() to load the pre-assign version of {uop_bufs[b]}.")
sink = full_ast_rewrite(sink, ctx:=ScheduleItemContext(var_vals))
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
@@ -398,7 +403,8 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
if buf.op is not MetaOps.CONST:output_groups[reduce_for_op.get(buf, buf)].append(buf_uops[buf.buffer])
# preschedule all buffers in realizes
prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, var_vals) for outs in output_groups.values()]
prescheduled = [_lower_lazybuffer([lazybufs_to_realize[uop_bufs[b]] for b in outs], buf_uops, uop_bufs,
var_vals) for outs in output_groups.values()]
schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)