mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 22:08:08 -05:00
fix
This commit is contained in:
@@ -439,10 +439,12 @@ def realize_view(ctx:ScheduleContext, view:UOp, src:UOp, b:UOp, **kwargs) -> Non
|
||||
# otherwise safety check pads
|
||||
return None if (all(v.mask is None for v in st.views) or can_pad(src, ctx.realizes, set())) else realize(ctx, b, src)
|
||||
|
||||
def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kwargs) -> UOp|None:
|
||||
if not isinstance(xb.dtype, ImageDType) or b not in ctx.realizes or xb not in ctx.realizes or uval(to_cast).op in GroupOp.Meta: return None
|
||||
def fold_img_cast(ctx:ScheduleContext, b:UOp, base:UOp, root:UOp, src:UOp, view:UOp):
|
||||
if not isinstance(b.dtype, ImageDType) or not isinstance(src.buf_uop.dtype, ImageDType): return None
|
||||
if b not in ctx.realizes: return None
|
||||
if not src.is_realized and src.base.op is Ops.COPY: return None
|
||||
del ctx.realizes[b]
|
||||
return to_cast.view(unwrap(view.st))
|
||||
return src.view(unwrap(view.st))
|
||||
|
||||
def sink_outputs(ctx:ScheduleContext, sink:UOp) -> None:
|
||||
for x in sink.src: realize(ctx, x.buf_uop, x)
|
||||
@@ -461,7 +463,7 @@ do_realize = PatternMatcher([
|
||||
# realize before expand or unsafe pad ops
|
||||
(UPatScheduled(name="src").view(name="view"), realize_view),
|
||||
# don't realize image to image casts
|
||||
(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
|
||||
(UPat(Ops.VIEW, name="view", src=(UPatScheduled(Ops.CAST, name="root", src=(UPat.var("src"),)),)), fold_img_cast),
|
||||
# realize before COPY or BUFFER_VIEW
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
|
||||
|
||||
Reference in New Issue
Block a user