mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
This reverts commit 540e4179e7.
This commit is contained in:
@@ -738,9 +738,9 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertEqual(out.lazydata.srcs[1].metadata.name, "sigmoid")
|
||||
s = create_schedule([out.lazydata])
|
||||
self.assertEqual(len(s[-1].metadata), 3)
|
||||
self.assertEqual(s[-1].metadata[0].name, "__mul__")
|
||||
self.assertEqual(s[-1].metadata[1].name, "relu")
|
||||
self.assertEqual(s[-1].metadata[2].name, "sigmoid")
|
||||
self.assertEqual(s[-1].metadata[0].name, "relu")
|
||||
self.assertEqual(s[-1].metadata[1].name, "sigmoid")
|
||||
self.assertEqual(s[-1].metadata[2].name, "__mul__")
|
||||
|
||||
def test_complex_backward(self):
|
||||
_METADATA.set(None)
|
||||
@@ -757,7 +757,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertEqual(len(s[-1].metadata), 3)
|
||||
self.assertEqual(s[-1].metadata[0].name, "sigmoid")
|
||||
self.assertEqual(s[-1].metadata[1].name, "sigmoid")
|
||||
self.assertTrue(s[-1].metadata[0].backward)
|
||||
self.assertTrue(s[-1].metadata[1].backward)
|
||||
self.assertEqual(s[-1].metadata[2].name, "relu")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -41,15 +41,14 @@ class ScheduleItem:
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleContext:
|
||||
realizes: Dict[Buffer, LazyBuffer]
|
||||
buf_metadata: Dict[UOp, Metadata] = field(default_factory=dict)
|
||||
buf_uops: Dict[Buffer, UOp] = field(default_factory=dict)
|
||||
uop_bufs: Dict[UOp, Buffer] = field(default_factory=dict)
|
||||
var_vals: Dict[Variable, int] = field(default_factory=dict)
|
||||
|
||||
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, metadata:Dict[UOp, Metadata], cache:Dict[LazyBuffer, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = to_uop(buf.base, outputs, ctx, cache).view(buf.st)
|
||||
cache[buf] = ret = to_uop(buf.base, outputs, ctx, metadata, cache).view(buf.st)
|
||||
return ret
|
||||
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
|
||||
# consts have VALID + value
|
||||
@@ -65,7 +64,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, cache:
|
||||
if buf.is_realized(): return UOp(UOps.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
|
||||
if b in ctx.realizes and buf not in outputs: return UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop()))
|
||||
# otherwise we fuse it like normal
|
||||
src = tuple(to_uop(x, outputs, ctx, cache) for x in buf.srcs)
|
||||
src = tuple(to_uop(x, outputs, ctx, metadata, cache) for x in buf.srcs)
|
||||
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
|
||||
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
|
||||
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
|
||||
@@ -74,7 +73,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], ctx:ScheduleContext, cache:
|
||||
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
|
||||
else: ret = UOp(UOps.ALU, dtype, src, buf.op)
|
||||
cache[buf] = ret = UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
|
||||
if buf.metadata is not None: ctx.buf_metadata[ubuf] = buf.metadata
|
||||
if buf.metadata is not None: metadata[ubuf] = buf.metadata
|
||||
return ret
|
||||
|
||||
# **** AST graph rewrite
|
||||
@@ -163,7 +162,6 @@ view_right = merge_views+PatternMatcher([
|
||||
class ScheduleItemContext:
|
||||
var_vals: Dict[Variable, int]
|
||||
assigned: Set[UOp]
|
||||
buf_metadata: Dict[UOp, Metadata]
|
||||
sts: Set[ShapeTracker] = field(default_factory=set)
|
||||
bufs: List[UOp] = field(default_factory=list)
|
||||
assign_preloads: List[UOp] = field(default_factory=list)
|
||||
@@ -194,11 +192,8 @@ to_si = PatternMatcher([
|
||||
|
||||
# ** fusion
|
||||
|
||||
def _fold_load(ctx:ScheduleItemContext, b:UOp, v:UOp) -> UOp:
|
||||
if (m:=ctx.buf_metadata.get(b)) is not None: ctx.metadata[m] = None
|
||||
return v
|
||||
lazy = PatternMatcher([
|
||||
(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), _fold_load),
|
||||
(UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), lambda b,v: v),
|
||||
])
|
||||
|
||||
multioutput = PatternMatcher([
|
||||
@@ -207,7 +202,7 @@ multioutput = PatternMatcher([
|
||||
|
||||
def full_ast_rewrite(pre:UOp, ctx:ScheduleItemContext) -> UOp:
|
||||
# fuse and fold store -> loads
|
||||
sink = graph_rewrite(pre, lazy, ctx)
|
||||
sink = graph_rewrite(pre, lazy)
|
||||
# fuse multi output
|
||||
if len(sink.src) > 1: sink = graph_rewrite(sink, multioutput, {x.src[0]:x.src[2] for x in sink.src})
|
||||
# assert cyclic dependency
|
||||
@@ -224,7 +219,7 @@ def full_ast_rewrite(pre:UOp, ctx:ScheduleItemContext) -> UOp:
|
||||
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
|
||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||
if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append((pre, ScheduleItemContext(ctx.var_vals, ctx.assigned, ctx.buf_metadata), sink))
|
||||
if getenv("RUN_PROCESS_REPLAY"): PROCESS_REPLAY_CAPTURE.append((pre, ScheduleItemContext(ctx.var_vals, ctx.assigned), sink))
|
||||
return sink
|
||||
|
||||
PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, ScheduleItemContext, UOp]] = []
|
||||
@@ -244,10 +239,11 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
for stores in store_groups:
|
||||
outs = [lazybufs_to_realize[b] for b in stores]
|
||||
cache: Dict[LazyBuffer, UOp] = {}
|
||||
to_store = tuple(to_uop(out, outs, ctx, cache) for out in outs)
|
||||
metadata: Dict[UOp, Metadata] = {}
|
||||
to_store = tuple(to_uop(out, outs, ctx, metadata, cache) for out in outs)
|
||||
sink = UOp(UOps.SINK, src=tuple(UOp.store(ctx.buf_uops[x.buffer], ShapeTracker.from_shape(x.shape).to_uop(), u) for x,u in zip(outs,to_store)))
|
||||
si_ctx = ScheduleItemContext(ctx.var_vals, {ubuf for x in assigns if (ubuf:=ctx.buf_uops.get(x.buffer)) is not None},
|
||||
ctx.buf_metadata, metadata={metadata:None for x in to_store if (metadata:=ctx.buf_metadata.get(x.src[0]))})
|
||||
metadata={x:None for x in metadata.values()})
|
||||
small_graphs.append((full_ast_rewrite(sink, si_ctx), si_ctx))
|
||||
|
||||
# do BFS
|
||||
|
||||
Reference in New Issue
Block a user